diff --git a/img/Accuracy_Final_Models.png b/img/Accuracy_Final_Models.png index 26314f3..9679782 100644 Binary files a/img/Accuracy_Final_Models.png and b/img/Accuracy_Final_Models.png differ diff --git a/img/Accuracy_for_3_MultiOpt.png b/img/Accuracy_for_3_MultiOpt.png index 73ad21d..b403ddb 100644 Binary files a/img/Accuracy_for_3_MultiOpt.png and b/img/Accuracy_for_3_MultiOpt.png differ diff --git a/img/Schematic.png b/img/Schematic.png index 08759cd..6c01679 100644 Binary files a/img/Schematic.png and b/img/Schematic.png differ diff --git a/readme.md b/readme.md index 94a2cd3..ca4c447 100644 --- a/readme.md +++ b/readme.md @@ -1,25 +1,73 @@ -# Prompting: a zero-shot language model to process multiple-choice questions on synonyms +# SynPL: a zero-shot language prompt model to process multiple-choice questions on synonyms ![Transformers](https://img.shields.io/badge/%F0%9F%A4%97%20TRANSFORMERS-4.10.1-blue) -Multiple-choice questions are a classic section in exams. When taking a language test such as TOEFL®, some questions need you to select the "best" choice among a set of four options of words or phrases that are the closest meaning to a "keyword" in the context of a reading passage. +- [SynPL: a zero-shot language prompt model to process multiple-choice questions on synonyms](#synpl-a-zero-shot-language-prompt-model-to-process-multiple-choice-questions-on-synonyms) + - [Overview](#overview) + - [Background](#background) + - [Approaches](#approaches) + - [Results](#results) + - [Findings](#findings) + - [Datasets](#datasets) + - [Question structure](#question-structure) + - [*AutoModelForMultipleChoice* model](#automodelformultiplechoice-model) + - [Three input patterns](#three-input-patterns) + - [Finetune](#finetune) + - [Results](#results-1) + - [Fill-mask prompt model](#fill-mask-prompt-model) + - [Background](#background-1) + - [Prompt idea](#prompt-idea) + - [Finetune](#finetune-1) + - [Definition of accuracy](#definition-of-accuracy) + - [Data processing](#data-processing) + - [Results](#results-2) + - [MultiOpt VS Prompt](#multiopt-vs-prompt) + - [Results](#results-3) + - [Requirements](#requirements) + - [Model files](#model-files) ## Overview -I developed 2 kinds of language models to solve this problem. The first is to use the `AutoModelForMultipleChoice` pre-trained model from [🤗TRANSFORMERS](https://huggingface.co/transformers). The second is to build a prompt for the classic fill-mask model, so that this multiple-choice task can be formulated as a masked language modeling problem, which is what pre-trained models like BERT are designed for in the first place. +### Background + +Multiple-choice questions are a classic section in exams. When taking a language test such as TOEFL®, some synonym questions ask you to select, among a set of four options of words or phrases, the "best" choice that is closest in meaning to a "keyword" in the context of a reading passage. + +My objective here is to build language models that process this kind of question automatically. + +### Approaches + +I developed two kinds of language models to solve this problem. -The figure below shows the schematic. +The first uses the `AutoModelForMultipleChoice` pre-trained model from [🤗TRANSFORMERS](https://huggingface.co/transformers). This is a generic model with a multiple-choice head, which yields a score for each input in a given selection; we select the input with the highest score as the most plausible one.
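+As a minimal sketch of how this head scores one question under the `Keyword [SEP] Options` pattern described below (illustrative only; the actual finetuning code lives in the notebooks):
+
+```python
+from transformers import AutoTokenizer, AutoModelForMultipleChoice
+
+tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+model = AutoModelForMultipleChoice.from_pretrained("bert-base-uncased")
+
+keyword = "prized"
+options = ["valued", "universal", "uncommon", "preserved"]
+
+# Encode one "keyword [SEP] option" pair per choice; the head expects
+# inputs of shape (batch_size, num_choices, seq_len), hence the extra axis.
+enc = tokenizer([keyword] * len(options), options, return_tensors="pt", padding=True)
+inputs = {k: v.unsqueeze(0) for k, v in enc.items()}
+
+logits = model(**inputs).logits               # shape (1, num_choices)
+print(options[logits.argmax(dim=-1).item()])  # highest-scoring option
+```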
-Where: +The second is to build a prompt for the classic fill-mask model, so that this multiple-choice task can be formulated as a masked language modeling problem, which is what pre-trained models like BERT are designed for in the first place. + +[Figure 1](#fig1) shows the schematic, where: - the **MultiOpt_KO** model is the *AutoModelForMultipleChoice* pre-trained BERT model finetuned with the `Keyword [SEP] Options` input pattern - the **PromptBi** model is the classic fill-mask model with a prompt -![Fig 1. Diagram of 3 input patterns](img/Schematic.png) + + +![Fig 1. Schematic for 2 models](img/Schematic.png) + +**Fig 1.** Schematic for 2 models + +### Results + +[Figure 2](#fig2) shows the final result of these 2 models, where **"Raw BERT"** is the un-finetuned *AutoModelForMultipleChoice* pre-trained BERT model. + + + +![Fig 2. Final result of 2 models](img/Accuracy_Final_Models.png) -And this is the final result of these 2 models. Where **"Raw BERT"** is un-finetuned *AutoModelForMultipleChoice* pre-trained BERT model. -![Fig 1. Diagram of 3 input patterns](img/Accuracy_Final_Models.png) +**Fig 2.** Final result of 2 models +### Findings + +- Like humans, the models don't need the context for this kind of synonym question +- The prompt model has the best performance in terms of accuracy, input conciseness, and the amount of training data needed (because it can first be finetuned on other datasets). +- To the best of my knowledge, prompt models are a novel approach to these kinds of problems. ## Datasets @@ -47,13 +95,13 @@ A typical example of synonym multiple-choice questions from TOEFL® test would b It consists of 4 parts: **Context**, **Question**, 4 **Options**, and the **Answer**. We can transform those strings via Python's [`re`](https://docs.python.org/3/library/re.html) library into a structured data frame.
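+As a rough sketch of that transformation (the raw question layout below is hypothetical; adapt the regular expression to the actual source files):
+
+```python
+import re
+import pandas as pd
+
+# Hypothetical raw question text; the real files may be formatted differently.
+raw = '''Most of these leaders were involved in public life.
+The word "representative" is closest in meaning to which of the following?
+(A) typical (B) satisfied (C) supportive (D) distinctive
+Answer: A'''
+
+pattern = re.compile(
+    r'(?P<Context>.+)\n(?P<Question>.+)\n'
+    r'\(A\) (?P<Opt1>.+?) \(B\) (?P<Opt2>.+?) \(C\) (?P<Opt3>.+?) \(D\) (?P<Opt4>.+?)\n'
+    r'Answer: (?P<Ans>[A-D])')
+
+df = pd.DataFrame([m.groupdict() for m in pattern.finditer(raw)])
+print(df[['Opt1', 'Opt2', 'Opt3', 'Opt4', 'Ans']])
+```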
-| ID | Context | Question | Opt1 | Opt2 | Opt3 | Opt4 | Ans | -|-----:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------------------------------------------|:----------|:----------|:--------------|:------------|:------| -| 0 | Most of these leaders were involved in public ... | The word "representative" is closest in meanin... | typical | satisfied | supportive | distinctive | A | -| 1 | In the United States, Louis Comfort Tiffany (1... | The word "prized" is closest in meaning to whi... | valued | universal | uncommon | preserved | A | -| 2 | The Art Nouveau style was a major force in the... | The word "overtaken" is closest in meaning to ... | surpassed | inclined | expressed | applied | A | -| 3 | During most of their lives, surge glaciers beh... | The word "intervals" is closest in meaning to ... | records | speeds | distances | periods | D | -| 4 | The increasing water pressure under the glacie... | The word "freeing" is closest in meaning to wh... | pushing | releasing | strengthening | draining | B | +| ID | Context | Question | Opt1 | Opt2 | Opt3 | Opt4 | Ans | +| ---: | :------------------------------------------------ | :------------------------------------------------ | :-------- | :-------- | :------------ | :---------- | :--- | +| 0 | Most of these leaders were involved in public ... | The word "representative" is closest in meanin... | typical | satisfied | supportive | distinctive | A | +| 1 | In the United States, Louis Comfort Tiffany (1... | The word "prized" is closest in meaning to whi... | valued | universal | uncommon | preserved | A | +| 2 | The Art Nouveau style was a major force in the... | The word "overtaken" is closest in meaning to ... | surpassed | inclined | expressed | applied | A | +| 3 | During most of their lives, surge glaciers beh... | The word "intervals" is closest in meaning to ... | records | speeds | distances | periods | D | +| 4 | The increasing water pressure under the glacie... | The word "freeing" is closest in meaning to wh... | pushing | releasing | strengthening | draining | B | ## *AutoModelForMultipleChoice* model @@ -69,9 +117,9 @@ There are 3 input patterns we can try to build and then feed into the model: 1. `Keyword [SEP] Options` - **MultiOpt_KO** 1. `Keyword sentence [SEP] Option sentences` - **MultiOpt_KsOs** -![Fig 1. Diagram of 3 input patterns](img/Three_Input_Patterns_MultiOpt.png) +![Fig 3. Diagram of 3 input patterns](img/Three_Input_Patterns_MultiOpt.png) -**Fig 1.** Diagram of 3 input patterns +**Fig 3.** Diagram of 3 input patterns. *Where **C** stands for **Context**, **K** stands for **Keyword**, **O** stands for **Options**, **Ks** stands for **Keyword sentence** and **Os** stands for **Option sentences**.* ### Finetune @@ -79,25 +127,25 @@ I divided each dataset 70% as train data and 30% as test data with the same rand ### Results -[Table 1](#tab1) and [Figure 2](#fig2) show the mean accuracies for 3 models. +[Table 1](#tab1) and [Figure 4](#fig4) show the mean accuracies for 3 models. **Table 1.** Evaluation Metrics for 3 *AutoModelForMultipleChoice* Models -| Model | Accuracy | Precision | Recall | F1 | AP | -|-------|----------|--|--|--|--| -| MultiOpt_CKO | 43.07% | 0.4329 | 0.4314 | 0.4306 | 0.4335| -| MultiOpt_KO | 72.26% | 0.7242 | 0.7211 | 0.7212 | 0.8751| -| MultiOpt_KsOs | 34.31% | 0.3428 | 0.3418 | 0.3415 | 0.5057| +| Model | Accuracy | Precision | Recall | F1 | AP | +| ------------- | -------- | --------- | ------ | ------ | ------ | +| MultiOpt_CKO | 46.16% | 0.4503 | 0.5186 | 0.3891 | 0.3520 | +| MultiOpt_KO | 72.17% | 0.7095 | 0.7230 | 0.7264 | 0.8902 | +| MultiOpt_KsOs | 35.23% | 0.3393 | 0.3490 | 0.3468 | 0.4865 | **MultiOpt_KO** (the *Keyword [SEP] Options* pattern) got the best result. Not only does it have the highest accuracy, but it is also the most robust, achieving nearly consistent prediction performance when trained from different initializations. - + -![Fig 2. Accuracy(%) of 3 Models](img/Accuracy_for_3_MultiOpt.png) +![Fig 4. Accuracy(%) of 3 Models](img/Accuracy_for_3_MultiOpt.png) -**Fig 2.** Accuracy(%) of 3 Models +**Fig 4.** Accuracy(%) of 3 Models This actually matches our intuition: when doing TOEFL® synonym multiple-choice questions, most of the time we don't have to read the context. By examining only the keyword and options, we can still pick out the best choice. ## Fill-mask prompt model ### Background @@ -113,9 +161,9 @@ Prompts are some small templates inserted into the inputs, with the aim that tas For example, if you get a text classification problem to grade the movie review "The drama discloses nothing", you can insert the text "It was ____" at the end of this review. After running your model, you just need to compare the logit scores of "terrible" and "great" to determine whether the review is positive or negative.
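+A minimal sketch of this idea with the 🤗 fill-mask pipeline (the review follows the example above; restricting scoring to the two verdict words via `targets` is one simple way to do it):
+
+```python
+from transformers import pipeline
+
+fill_mask = pipeline("fill-mask", model="bert-base-uncased")
+
+review = "The drama discloses nothing. It was [MASK]."
+
+# Score only the two candidate verdict words at the masked position.
+for result in fill_mask(review, targets=["terrible", "great"]):
+    print(result["token_str"], result["score"])
+```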
-![Fig 3. A typical example of using a prompt](img/Prompt_Example.png) +![Fig 5. A typical example of using a prompt](img/Prompt_Example.png) -**Fig 3.** A typical example of using a prompt +**Fig 5.** A typical example of using a prompt You can read [Gao's post](https://thegradient.pub/prompting/) for more information. @@ -131,13 +179,13 @@ This prompt can be performed like this on the TOEFL® dataset: ### Finetune -We still need to finetune the model with this prompt, because we should let the model know that it's facing a synonym problem. See [Figure 4](#fig4) and [Table 2](#tab2) for how this step can boost its performance. +We still need to finetune the model with this prompt, so that the model knows it is facing a synonym problem. See [Figure 6](#fig6) and [Table 2](#tab2) for how this step boosts its performance. - + -![Fig 4. An example of before and after finetuning](img/After_Before_Finetune.png) +![Fig 6. An example of before and after finetuning](img/After_Before_Finetune.png) -**Fig 4.** An example of before and after finetuning. +**Fig 6.** An example of before and after finetuning. *Here you can see that, after finetuning, its output words are more like synonyms.* ### Definition of accuracy @@ -162,36 +210,36 @@ In conclusion, there are 3 models in this part. **Table 2.** Accuracy of prompt models -| Model |Top 1 |Top 2 |Top 3 |Top 5 |Top 10 | -|:---------|:------|:------|:------|:------|:------| -|Raw_BERT |5.51% |9.68% |12.24% |15.47% |20.48% | -|Prompt_Uni |12.13% |33.12% |44.34% |55.39% |65.13% | -|Prompt_Bi |13.01% |43.24% |53.12% |64.93% |74.12% | +| Model | Top 1 | Top 2 | Top 3 | Top 5 | Top 10 | +| :--------- | :----- | :----- | :----- | :----- | :----- | +| Raw_BERT | 5.51% | 9.68% | 12.24% | 15.47% | 20.48% | +| Prompt_Uni | 12.13% | 33.12% | 44.34% | 55.39% | 65.13% | +| Prompt_Bi | 13.01% | 43.24% | 53.12% | 64.93% | 74.12% | - + -![Fig 5. Accuracy of prompt models](img/Accuracy_prompt.png) +![Fig 7. Accuracy of prompt models](img/Accuracy_prompt.png) -**Fig 5.** Accuracy of prompt models +**Fig 7.** Accuracy of prompt models Below are some other metrics. -| Model | N | Recall | F1 | -|:-----------|----:|---------:|-------:| -| Raw_BERT | 1 | 0.0275 | 0.0522 | -| Prompt_Uni | 1 | 0.0587 | 0.1051 | -| Prompt_Bi | 1 | 0.0662 | 0.117 | -| Raw_BERT | 2 | 0.0484 | 0.0883 | -| Prompt_Uni | 2 | 0.1661 | 0.2494 | -| Prompt_Bi | 2 | 0.2101 | 0.2958 | -| Raw_BERT | 3 | 0.0612 | 0.1091 | -| Prompt_Uni | 3 | 0.2179 | 0.3035 | -| Prompt_Bi | 3 | 0.2643 | 0.3458 | -| Raw_BERT | 5 | 0.0774 | 0.134 | -| Prompt_Uni | 5 | 0.2755 | 0.3552 | -| Prompt_Bi | 5 | 0.3214 | 0.3913 | -| Raw_BERT | 10 | 0.1024 | 0.17 | -| Prompt_Uni | 10 | 0.3367 | 0.4024 | -| Prompt_Bi | 10 | 0.3745 | 0.4283 | +| Model | N | Recall | F1 | +| :--------- | ---: | -----: | -----: | +| Raw_BERT | 1 | 0.0275 | 0.0522 | +| Prompt_Uni | 1 | 0.0587 | 0.1051 | +| Prompt_Bi | 1 | 0.0662 | 0.117 | +| Raw_BERT | 2 | 0.0484 | 0.0883 | +| Prompt_Uni | 2 | 0.1661 | 0.2494 | +| Prompt_Bi | 2 | 0.2101 | 0.2958 | +| Raw_BERT | 3 | 0.0612 | 0.1091 | +| Prompt_Uni | 3 | 0.2179 | 0.3035 | +| Prompt_Bi | 3 | 0.2643 | 0.3458 | +| Raw_BERT | 5 | 0.0774 | 0.134 | +| Prompt_Uni | 5 | 0.2755 | 0.3552 | +| Prompt_Bi | 5 | 0.3214 | 0.3913 | +| Raw_BERT | 10 | 0.1024 | 0.17 | +| Prompt_Uni | 10 | 0.3367 | 0.4024 | +| Prompt_Bi | 10 | 0.3745 | 0.4283 | ## MultiOpt VS Prompt @@ -207,17 +255,17 @@ These models all use the same test data, which means data where all phrases alon ### Results -[Table 3](#tab3) shows how these models perform on the TOEFL® dataset.
+[Table 3](#tab3) and [Figure 2](#fig2) show how these models perform on the TOEFL® dataset. **Table 3.** Metrics of 3 final models - + -| Model | Accuracy | Precision | Recall | F1 | AP | -|:-------|:----------|:--|:--|:--|:--| -|Raw_BERT |40.88% |0.4034|0.4066|0.4041|0.5956| -|MultiOpt_KO |72.99% |0.7355|0.7291|0.7287|0.887| -|Prompt_Bi |89.78% |0.9014|0.9001|0.9001|0.9735| +| Model | Accuracy | Precision | Recall | F1 | AP | +| :---------- | :------- | :-------- | :----- | :----- | :----- | +| Raw_BERT | 40.88% | 0.4034 | 0.4066 | 0.4041 | 0.5956 | +| MultiOpt_KO | 73.18% | 0.7385 | 0.7217 | 0.7341 | 0.8894 | +| Prompt_Bi | 89.69% | 0.9035 | 0.8977 | 0.9020 | 0.9683 | ## Requirements diff --git a/src/Prompt_Bi.ipynb b/src/Prompt_Bi.ipynb new file mode 100644 index 0000000..2fa8a64 --- /dev/null +++ b/src/Prompt_Bi.ipynb @@ -0,0 +1,1848 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Prompt_Bi.ipynb", + "provenance": [], + "collapsed_sections": [ + "T9Yzq7FXO4V0" + ] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "ywMZXOrtMlDe" + }, + "source": [ + "*This version has no duplicated words in the test dataset*" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "iNtLy60GSiBb" + }, + "source": [ + "! pip install datasets transformers --upgrade" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AUcrPhh5W8n3" + }, + "source": [ + "# Extract synonyms that the tokenizer supports\n", + "Load the *WordNet* data" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ls7wkjJjWptt", + "outputId": "d7df07f6-ca78-47d2-d827-b673eca55292" + }, + "source": [ + "import pandas as pd\n", + "from nltk.corpus import wordnet as wn\n", + "import nltk\n", + "nltk.download('wordnet')" + ], + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[nltk_data] Downloading package wordnet to /root/nltk_data...\n", + "[nltk_data] Package wordnet is already up-to-date!\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "True" + ] + }, + "metadata": {}, + "execution_count": 2 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GUgysOdZCKgo" + }, + "source": [ + "Load the Tokenizer" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "F9cjvuPlCEMW", + "outputId": "8bc346e7-3a4b-4d10-af13-f03ac1d6fafb" + }, + "source": [ + "from transformers import AutoTokenizer\n", + "\n", + "model_checkpoint = \"bert-base-uncased\"\n", + "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n", + "tokenizer" + ], + "execution_count": 3, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "PreTrainedTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_len=512, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})" + ] + }, + "metadata": {}, + "execution_count": 3 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gh864RlSCTi1" + }, + "source": [ + "Go through the words in the tokenizer" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": {
"base_uri": "https://localhost:8080/" + }, + "id": "VIrHOVJMCYhe", + "outputId": "d0523535-dd28-4a4a-b582-f9dbb941c33c" + }, + "source": [ + "len(tokenizer.get_vocab())" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "30522" + ] + }, + "metadata": {}, + "execution_count": 4 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "E0-kDLaHXECc" + }, + "source": [ + "synonyms = []\n", + "\n", + "for synset in list(wn.all_synsets()):\n", + " temp = synset.lemma_names()[:]\n", + " if len(temp)<=1:\n", + " continue\n", + " temp2 = []\n", + " for i in temp:\n", + " if len(tokenizer(i)['input_ids'])==3:\n", + " temp2.append(i)\n", + " if len(temp2)<=1:\n", + " continue \n", + " synonyms.append(temp2)" + ], + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "aSzPwBdYHqQn", + "outputId": "44b05b5d-c0ad-44ee-ab46-2af4cefbb20a" + }, + "source": [ + "len(synonyms)" + ], + "execution_count": 6, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "7747" + ] + }, + "metadata": {}, + "execution_count": 6 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZRErscwjkkGZ" + }, + "source": [ + "Extract words" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "MwbBx9UQXfUP" + }, + "source": [ + "from itertools import combinations\n", + "\n", + "df = []\n", + "for line in synonyms:\n", + " comb = combinations(line, 2)\n", + " for i in list(comb):\n", + " temp1 = line[:]\n", + " temp1.remove(i[0])\n", + " temp2 = line[:]\n", + " temp2.remove(i[1])\n", + " df.append({\n", + " 'word1': i[0],\n", + " 'word2': i[1],\n", + " 'synonyms1': temp1,\n", + " 'synonyms2': temp2,\n", + " })" + ], + "execution_count": 7, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "jkyZaW0clChj", + "outputId": "6109ddf1-f4d8-4298-94c9-39bca8a35d22" + }, + "source": [ + "df[:10]" + ], + "execution_count": 8, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[{'synonyms1': ['shortened'],\n", + " 'synonyms2': ['cut'],\n", + " 'word1': 'cut',\n", + " 'word2': 'shortened'},\n", + " {'synonyms1': ['comparative'],\n", + " 'synonyms2': ['relative'],\n", + " 'word1': 'relative',\n", + " 'word2': 'comparative'},\n", + " {'synonyms1': ['tangible'],\n", + " 'synonyms2': ['real'],\n", + " 'word1': 'real',\n", + " 'word2': 'tangible'},\n", + " {'synonyms1': ['rich'],\n", + " 'synonyms2': ['ample'],\n", + " 'word1': 'ample',\n", + " 'word2': 'rich'},\n", + " {'synonyms1': ['faithful'],\n", + " 'synonyms2': ['close'],\n", + " 'word1': 'close',\n", + " 'word2': 'faithful'},\n", + " {'synonyms1': ['outside'],\n", + " 'synonyms2': ['away'],\n", + " 'word1': 'away',\n", + " 'word2': 'outside'},\n", + " {'synonyms1': ['incorrect', 'wrong'],\n", + " 'synonyms2': ['faulty', 'wrong'],\n", + " 'word1': 'faulty',\n", + " 'word2': 'incorrect'},\n", + " {'synonyms1': ['incorrect', 'wrong'],\n", + " 'synonyms2': ['faulty', 'incorrect'],\n", + " 'word1': 'faulty',\n", + " 'word2': 'wrong'},\n", + " {'synonyms1': ['faulty', 'wrong'],\n", + " 'synonyms2': ['faulty', 'incorrect'],\n", + " 'word1': 'incorrect',\n", + " 'word2': 'wrong'},\n", + " {'synonyms1': ['recognized', 'recognised'],\n", + " 'synonyms2': ['accepted', 'recognised'],\n", + " 'word1': 'accepted',\n", + " 'word2': 'recognized'}]" + ] + }, + "metadata": 
{}, + "execution_count": 8 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "IDgv7birH8G2", + "outputId": "fcb0d498-e132-4aaa-b8f9-492699ef7cdb" + }, + "source": [ + "len(df)" + ], + "execution_count": 9, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "15586" + ] + }, + "metadata": {}, + "execution_count": 9 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "V0xhUtXOJp4U" + }, + "source": [ + "# Generate sentences\n", + "## Version 1: unidirection" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "06_zFzx6pH75" + }, + "source": [ + "# datasets = []\n", + "\n", + "# for line in df:\n", + "# sen = f'{line[0]} is close in meaning to {line[1]}.'\n", + "# word = line[1]\n", + "# datasets.append({'sen':sen, 'word':word})\n", + "\n", + "# datasets[:10]" + ], + "execution_count": 10, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "s7FUggt0odpM" + }, + "source": [ + "# len(datasets)" + ], + "execution_count": 11, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "489q8m17M3Wd" + }, + "source": [ + "## Version 2: bi-direction" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "H1zf2iPeNA2-", + "outputId": "ae7b5431-19d1-4613-dc08-813b802bf05a" + }, + "source": [ + "datasets = []\n", + "\n", + "for line in df:\n", + " sen = f'{line[\"word1\"]} is close in meaning to {line[\"word2\"]}.'\n", + " datasets.append({'sen':sen, 'word1':line[\"word1\"], 'word2':line[\"word2\"], 'synonyms':line['synonyms1']})\n", + " sen = f'{line[\"word2\"]} is close in meaning to {line[\"word1\"]}.'\n", + " datasets.append({'sen':sen, 'word1':line[\"word2\"], 'word2':line[\"word1\"], 'synonyms':line['synonyms2']})\n", + "\n", + "datasets[:20]" + ], + "execution_count": 12, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[{'sen': 'cut is close in meaning to shortened.',\n", + " 'synonyms': ['shortened'],\n", + " 'word1': 'cut',\n", + " 'word2': 'shortened'},\n", + " {'sen': 'shortened is close in meaning to cut.',\n", + " 'synonyms': ['cut'],\n", + " 'word1': 'shortened',\n", + " 'word2': 'cut'},\n", + " {'sen': 'relative is close in meaning to comparative.',\n", + " 'synonyms': ['comparative'],\n", + " 'word1': 'relative',\n", + " 'word2': 'comparative'},\n", + " {'sen': 'comparative is close in meaning to relative.',\n", + " 'synonyms': ['relative'],\n", + " 'word1': 'comparative',\n", + " 'word2': 'relative'},\n", + " {'sen': 'real is close in meaning to tangible.',\n", + " 'synonyms': ['tangible'],\n", + " 'word1': 'real',\n", + " 'word2': 'tangible'},\n", + " {'sen': 'tangible is close in meaning to real.',\n", + " 'synonyms': ['real'],\n", + " 'word1': 'tangible',\n", + " 'word2': 'real'},\n", + " {'sen': 'ample is close in meaning to rich.',\n", + " 'synonyms': ['rich'],\n", + " 'word1': 'ample',\n", + " 'word2': 'rich'},\n", + " {'sen': 'rich is close in meaning to ample.',\n", + " 'synonyms': ['ample'],\n", + " 'word1': 'rich',\n", + " 'word2': 'ample'},\n", + " {'sen': 'close is close in meaning to faithful.',\n", + " 'synonyms': ['faithful'],\n", + " 'word1': 'close',\n", + " 'word2': 'faithful'},\n", + " {'sen': 'faithful is close in meaning to close.',\n", + " 'synonyms': ['close'],\n", + " 'word1': 'faithful',\n", + " 'word2': 'close'},\n", + " {'sen': 'away is close in meaning to outside.',\n", + " 'synonyms': 
['outside'],\n", + " 'word1': 'away',\n", + " 'word2': 'outside'},\n", + " {'sen': 'outside is close in meaning to away.',\n", + " 'synonyms': ['away'],\n", + " 'word1': 'outside',\n", + " 'word2': 'away'},\n", + " {'sen': 'faulty is close in meaning to incorrect.',\n", + " 'synonyms': ['incorrect', 'wrong'],\n", + " 'word1': 'faulty',\n", + " 'word2': 'incorrect'},\n", + " {'sen': 'incorrect is close in meaning to faulty.',\n", + " 'synonyms': ['faulty', 'wrong'],\n", + " 'word1': 'incorrect',\n", + " 'word2': 'faulty'},\n", + " {'sen': 'faulty is close in meaning to wrong.',\n", + " 'synonyms': ['incorrect', 'wrong'],\n", + " 'word1': 'faulty',\n", + " 'word2': 'wrong'},\n", + " {'sen': 'wrong is close in meaning to faulty.',\n", + " 'synonyms': ['faulty', 'incorrect'],\n", + " 'word1': 'wrong',\n", + " 'word2': 'faulty'},\n", + " {'sen': 'incorrect is close in meaning to wrong.',\n", + " 'synonyms': ['faulty', 'wrong'],\n", + " 'word1': 'incorrect',\n", + " 'word2': 'wrong'},\n", + " {'sen': 'wrong is close in meaning to incorrect.',\n", + " 'synonyms': ['faulty', 'incorrect'],\n", + " 'word1': 'wrong',\n", + " 'word2': 'incorrect'},\n", + " {'sen': 'accepted is close in meaning to recognized.',\n", + " 'synonyms': ['recognized', 'recognised'],\n", + " 'word1': 'accepted',\n", + " 'word2': 'recognized'},\n", + " {'sen': 'recognized is close in meaning to accepted.',\n", + " 'synonyms': ['accepted', 'recognised'],\n", + " 'word1': 'recognized',\n", + " 'word2': 'accepted'}]" + ] + }, + "metadata": {}, + "execution_count": 12 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Rnkkb8ULNA2_", + "outputId": "ae429191-b2b1-4703-820f-bcdf9acaf35e" + }, + "source": [ + "len(datasets)" + ], + "execution_count": 13, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "31172" + ] + }, + "metadata": {}, + "execution_count": 13 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OPsjpSaKTGp0" + }, + "source": [ + "# Create datasets" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "oSBJvjqhPTXA" + }, + "source": [ + "from numpy import random, logical_not\n", + "\n", + "random.seed(123)\n", + "\n", + "# Randomly pick some words as test data\n", + "datasets_df = pd.DataFrame(datasets)\n", + "temp = datasets_df['word1'].unique()\n", + "test_words = random.choice(temp, replace=False, size=round(0.3*temp.size))\n", + "\n", + "# Extract those words\n", + "flag = datasets_df['word1'].isin(test_words)\n", + "test = datasets_df.loc[flag,:].reset_index(drop=True)\n", + "train = datasets_df.loc[logical_not(flag),:].reset_index(drop=True)" + ], + "execution_count": 14, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "QCeNYop9Pr9Q", + "outputId": "d1985ef3-5196-4382-f93d-b7c7d3b39e26" + }, + "source": [ + "print(test_words.size)\n", + "print(train.shape[0])" + ], + "execution_count": 15, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "2496\n", + "21952\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 206 + }, + "id": "sN3aqodzTLo6", + "outputId": "9315cb77-5d49-4202-e1b4-8c14962b8a7b" + }, + "source": [ + "test.head()" + ], + "execution_count": 16, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
senword1word2synonyms
0real is close in meaning to tangible.realtangible[tangible]
1ample is close in meaning to rich.amplerich[rich]
2faithful is close in meaning to close.faithfulclose[close]
3away is close in meaning to outside.awayoutside[outside]
4outside is close in meaning to away.outsideaway[away]
\n", + "
" + ], + "text/plain": [ + " sen word1 word2 synonyms\n", + "0 real is close in meaning to tangible. real tangible [tangible]\n", + "1 ample is close in meaning to rich. ample rich [rich]\n", + "2 faithful is close in meaning to close. faithful close [close]\n", + "3 away is close in meaning to outside. away outside [outside]\n", + "4 outside is close in meaning to away. outside away [away]" + ] + }, + "metadata": {}, + "execution_count": 16 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 206 + }, + "id": "td4cRVffYpZl", + "outputId": "c8a18ffb-efd2-4aee-8244-d5b79554ec96" + }, + "source": [ + "train.head()" + ], + "execution_count": 17, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
senword1word2synonyms
0cut is close in meaning to shortened.cutshortened[shortened]
1shortened is close in meaning to cut.shortenedcut[cut]
2relative is close in meaning to comparative.relativecomparative[comparative]
3comparative is close in meaning to relative.comparativerelative[relative]
4tangible is close in meaning to real.tangiblereal[real]
\n", + "
" + ], + "text/plain": [ + " sen ... synonyms\n", + "0 cut is close in meaning to shortened. ... [shortened]\n", + "1 shortened is close in meaning to cut. ... [cut]\n", + "2 relative is close in meaning to comparative. ... [comparative]\n", + "3 comparative is close in meaning to relative. ... [relative]\n", + "4 tangible is close in meaning to real. ... [real]\n", + "\n", + "[5 rows x 4 columns]" + ] + }, + "metadata": {}, + "execution_count": 17 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8-kMCQzbZnI5" + }, + "source": [ + "Clear test dataset, word1 can only appear once" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 677 + }, + "id": "bEO_Xd0PzpOm", + "outputId": "0f6b19d5-7d1e-4dba-8535-ba9fb41b9cc9" + }, + "source": [ + "test = test.groupby('word1').agg({'sen':'first', 'synonyms': sum, 'word2':'first'}).reset_index()\n", + "\n", + "def get_syn(row):\n", + " synonyms = []\n", + " for syn in wn.synsets(row):\n", + " for l in syn.lemmas():\n", + " if len(tokenizer(l.name())['input_ids'])==3:\n", + " synonyms.append(l.name())\n", + " synonyms = list(set(synonyms))\n", + " synonyms.remove(row)\n", + " return synonyms\n", + "\n", + "test['synonyms'] = test['word1'].apply(lambda x: get_syn(x))\n", + "\n", + "test.iloc[250:270,:]" + ], + "execution_count": 18, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
word1sensynonymsword2
250agriculturalagricultural is close in meaning to agrarian.[farming, agrarian]agrarian
251aidedaided is close in meaning to assisted.[aid, help, assisted, assist]assisted
252ainain is close in meaning to own.[own]own
253airfieldairfield is close in meaning to field.[field]field
254airliftairlift is close in meaning to lift.[lift]lift
255airplaneairplane is close in meaning to plane.[plane]plane
256akinakin is close in meaning to kin.[kin]kin
257alignalign is close in meaning to adjust.[adjust, array, coordinate]adjust
258alivealive is close in meaning to live.[awake, active, alert, live, animated]live
259alliancealliance is close in meaning to confederation.[bond, alignment, coalition, confederation]confederation
260alloyalloy is close in meaning to metal.[metal]metal
261alongsidealongside is close in meaning to aboard.[aboard]aboard
262alternativelyalternatively is close in meaning to instead.[instead]instead
263altogetheraltogether is close in meaning to wholly.[raw, totally, wholly, all, completely, whole,...wholly
264aluminumaluminum is close in meaning to aluminium.[aluminium, Al]aluminium
265amalgamationamalgamation is close in meaning to merger.[merger, uniting]merger
266amazedamazed is close in meaning to astonished.[beat, pose, gravel, get, astonished, puzzle, ...astonished
267amazinglyamazingly is close in meaning to surprisingly.[surprisingly]surprisingly
268amendamend is close in meaning to remedy.[better, remedy, improve, repair]remedy
269amountamount is close in meaning to measure.[total, measure, quantity, number, come, sum]measure
\n", + "
" + ], + "text/plain": [ + " word1 ... word2\n", + "250 agricultural ... agrarian\n", + "251 aided ... assisted\n", + "252 ain ... own\n", + "253 airfield ... field\n", + "254 airlift ... lift\n", + "255 airplane ... plane\n", + "256 akin ... kin\n", + "257 align ... adjust\n", + "258 alive ... live\n", + "259 alliance ... confederation\n", + "260 alloy ... metal\n", + "261 alongside ... aboard\n", + "262 alternatively ... instead\n", + "263 altogether ... wholly\n", + "264 aluminum ... aluminium\n", + "265 amalgamation ... merger\n", + "266 amazed ... astonished\n", + "267 amazingly ... surprisingly\n", + "268 amend ... remedy\n", + "269 amount ... measure\n", + "\n", + "[20 rows x 4 columns]" + ] + }, + "metadata": {}, + "execution_count": 18 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "PTW_RZ9MsobY" + }, + "source": [ + "from datasets import DatasetDict, Dataset\n", + "\n", + "datasets = DatasetDict({\n", + " 'train': Dataset.from_pandas(train),\n", + " 'test': Dataset.from_pandas(test)})" + ], + "execution_count": 19, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "WyCJmzA8whUs", + "outputId": "10ba9093-7780-4e60-f897-436e625355da" + }, + "source": [ + "datasets" + ], + "execution_count": 20, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "DatasetDict({\n", + " train: Dataset({\n", + " features: ['sen', 'word1', 'word2', 'synonyms'],\n", + " num_rows: 21952\n", + " })\n", + " test: Dataset({\n", + " features: ['word1', 'sen', 'synonyms', 'word2'],\n", + " num_rows: 2496\n", + " })\n", + "})" + ] + }, + "metadata": {}, + "execution_count": 20 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bT48GpTXUS-j" + }, + "source": [ + "# Tokenize" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "qKBVp1snG3DY" + }, + "source": [ + "def findindex(seq, subseq):\n", + " # brute-force approach O(n*m)\n", + " # Usage: findindex([4,3,'a',5,6], [5,6])\n", + " i, n, m = -1, len(seq), len(subseq)\n", + " try:\n", + " while True:\n", + " i = seq.index(subseq[0], i + 1, n - m + 1)\n", + " if subseq == seq[i:i + m]:\n", + " return i\n", + " except ValueError:\n", + " return -1" + ], + "execution_count": 21, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "N_vaTk3jylbn" + }, + "source": [ + "def tokenize_function(examples):\n", + " if length==0:\n", + " sen_no_mask = tokenizer(examples[\"sen\"], truncation=True, padding=True)\n", + " else:\n", + " sen_no_mask = tokenizer(examples[\"sen\"], truncation=True, padding='max_length', max_length=length)\n", + " masked_word = tokenizer(examples[\"word2\"], truncation=True, padding=True)['input_ids']\n", + " inputs = []\n", + " labels = []\n", + " \n", + " for i in range(len(masked_word)):\n", + " # Find the word encoding part\n", + " start = masked_word[i].index(101)\n", + " end = masked_word[i].index(102)\n", + " # Extract this part\n", + " temp = [masked_word[i][j] for j in range(start+1,end)]\n", + " # Find the same part in input and mask them\n", + " ipt = sen_no_mask['input_ids'][i][:]\n", + " idx = findindex(ipt, temp)\n", + " for j in range(len(temp)):\n", + " ipt[idx+j] = tokenizer.mask_token_id\n", + " inputs.append(ipt)\n", + " # Find the other part and replace those unmasked indices with -100\n", + " label = sen_no_mask['input_ids'][i][:]\n", + " for j in range(len(ipt)):\n", + " if ipt[j]!=tokenizer.mask_token_id:\n", + " label[j]=-100\n", + " 
labels.append(label)\n", + " \n", + " # encode synonyms\n", + " synonyms = []\n", + " for i in examples[\"synonyms\"]:\n", + " temp = tokenizer(i)['input_ids']\n", + " synonym = []\n", + " for j in temp:\n", + " synonym.append(j[1])\n", + " synonyms.append(synonym)\n", + "\n", + " sen_no_mask['input_ids']=inputs\n", + " sen_no_mask['label']=labels\n", + " sen_no_mask['synonyms']=synonyms\n", + " return sen_no_mask\n", + "\n", + "length = 0\n", + "tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=4, remove_columns=[\"sen\", \"word2\", \"word1\"])" + ], + "execution_count": 22, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "woN2JzloVZOP", + "outputId": "a941a153-3374-483c-fcd8-687478d62570" + }, + "source": [ + "for i in datasets['test']['synonyms'][0]:\n", + " print(tokenizer(i)['input_ids'])" + ], + "execution_count": 23, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[101, 1047, 102]\n", + "[101, 2882, 102]\n", + "[101, 1043, 102]\n", + "[101, 4220, 102]\n", + "[101, 1049, 102]\n", + "[101, 1049, 102]\n", + "[101, 1047, 102]\n", + "[101, 15223, 102]\n", + "[101, 4595, 102]\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "YeJ7jSNy1LIg", + "outputId": "9d1c5ba7-058e-4f79-ebd3-67395683303f" + }, + "source": [ + "for i in range(5):\n", + " print(tokenizer.decode(tokenized_datasets['train']['input_ids'][i]))\n", + " print(datasets['train']['sen'][i])\n", + " print(tokenized_datasets['train']['synonyms'][i])\n", + " print(datasets['train']['synonyms'][i])" + ], + "execution_count": 24, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[CLS] cut is close in meaning to [MASK]. [SEP]\n", + "cut is close in meaning to shortened.\n", + "[12641]\n", + "['shortened']\n", + "[CLS] shortened is close in meaning to [MASK]. [SEP]\n", + "shortened is close in meaning to cut.\n", + "[3013]\n", + "['cut']\n", + "[CLS] relative is close in meaning to [MASK]. [SEP]\n", + "relative is close in meaning to comparative.\n", + "[12596]\n", + "['comparative']\n", + "[CLS] comparative is close in meaning to [MASK]. [SEP]\n", + "comparative is close in meaning to relative.\n", + "[5816]\n", + "['relative']\n", + "[CLS] tangible is close in meaning to [MASK]. 
[SEP]\n", + "tangible is close in meaning to real.\n", + "[2613]\n", + "['real']\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "v9tXfljPUyFu", + "outputId": "08ca33da-bec3-4a5d-ae27-07d50b485d54" + }, + "source": [ + "tokenized_datasets" + ], + "execution_count": 25, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "DatasetDict({\n", + " train: Dataset({\n", + " features: ['attention_mask', 'input_ids', 'label', 'synonyms', 'token_type_ids'],\n", + " num_rows: 21952\n", + " })\n", + " test: Dataset({\n", + " features: ['attention_mask', 'input_ids', 'label', 'synonyms', 'token_type_ids'],\n", + " num_rows: 2496\n", + " })\n", + "})" + ] + }, + "metadata": {}, + "execution_count": 25 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IBAfzIFAYyN7" + }, + "source": [ + "# Train" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "P4aj7nPveYWA" + }, + "source": [ + "from transformers import DataCollatorWithPadding\n", + "collator = DataCollatorWithPadding(tokenizer=tokenizer)" + ], + "execution_count": 26, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "k4rOvTn-U8-X", + "outputId": "5b586138-5f43-45a3-b062-cc95f5a37d14" + }, + "source": [ + "from transformers import AutoConfig, AutoModelForMaskedLM\n", + "\n", + "model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)" + ], + "execution_count": 27, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n", + "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "tk6OxMWA0yHy", + "outputId": "301581c1-57f5-4d0d-9891-356e3f9ea73e" + }, + "source": [ + "model_raw = AutoModelForMaskedLM.from_pretrained(model_checkpoint)" + ], + "execution_count": 28, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n", + "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Km9DRtCyR4p8", + "outputId": "1411470c-9c16-4fc9-ccad-6839a70f9a55" + }, + "source": [ + "from transformers import pipeline\n", + "\n", + "camembert_fill_mask = pipeline(\"fill-mask\", model=model_raw, tokenizer=tokenizer)\n", + "print(tokenizer.decode(tokenized_datasets['test']['input_ids'][287]))\n", + "\n", + "results = camembert_fill_mask(tokenizer.decode(tokenized_datasets['test']['input_ids'][287]))\n", + "results" + ], + "execution_count": 29, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[CLS] approved is close in meaning to [MASK]. [SEP]\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[{'score': 0.04130351543426514,\n", + " 'sequence': 'approved is close in meaning to english.',\n", + " 'token': 2394,\n", + " 'token_str': 'english'},\n", + " {'score': 0.011473696678876877,\n", + " 'sequence': 'approved is close in meaning to latin.',\n", + " 'token': 3763,\n", + " 'token_str': 'latin'},\n", + " {'score': 0.007926260121166706,\n", + " 'sequence': 'approved is close in meaning to earth.',\n", + " 'token': 3011,\n", + " 'token_str': 'earth'},\n", + " {'score': 0.007475481368601322,\n", + " 'sequence': 'approved is close in meaning to it.',\n", + " 'token': 2009,\n", + " 'token_str': 'it'},\n", + " {'score': 0.004972452763468027,\n", + " 'sequence': 'approved is close in meaning to greek.',\n", + " 'token': 3306,\n", + " 'token_str': 'greek'}]" + ] + }, + "metadata": {}, + "execution_count": 29 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4FpnH28cktBa" + }, + "source": [ + "### Check the performance of the original pretrained model\n", + "Accuracy is defined as whether the true answer is among its top N predictions (e.g. with N = 5, a prediction counts as correct if any known synonym appears in the 5 highest-scoring tokens for the mask)."
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "cSoTHw7f5zu0" + }, + "source": [ + "import torch\n", + "\n", + "N = 5" + ], + "execution_count": 31, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "Qcus2DvVMdot" + }, + "source": [ + "import operator\n", + "from functools import reduce\n", + "\n", + "def common_member(a, b):\n", + " a_set = set(a)\n", + " b_set = set(b)\n", + " if (a_set & b_set):\n", + " return True \n", + " else:\n", + " return False\n", + "\n", + "def check_accuracy(df, lb, outputs, n=5):\n", + " mask_token_index = torch.where(df['input_ids'] == tokenizer.mask_token_id)[1]\n", + " mask_token_logits = outputs.logits[range(outputs.logits.size()[0]), mask_token_index, :]\n", + " top_n_tokens = torch.topk(mask_token_logits, n, dim=1).indices.tolist()\n", + " return [1 if common_member(i, top_n_tokens[j]) else 0 for j,i in enumerate(lb)]\n", + "\n", + "def get_accuracy(df, shards):\n", + " accuracies = []\n", + " for i in range(shards):\n", + " pm_inputs = df.shard(num_shards=shards, index=i)\n", + " test_inputs = dict((k, torch.LongTensor(pm_inputs[k])) for k in ['attention_mask', 'input_ids', 'token_type_ids'])\n", + " test_labels = torch.LongTensor(pm_inputs['label'])\n", + " outputs = model(**test_inputs, labels=test_labels)\n", + " accuracy = check_accuracy(test_inputs, pm_inputs['synonyms'], outputs, N)\n", + " accuracies.append(accuracy)\n", + " return reduce(operator.concat, accuracies)" + ], + "execution_count": 32, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "4-RhqJmaGhv8" + }, + "source": [ + "accuracy = get_accuracy(tokenized_datasets['test'], 20)" + ], + "execution_count": 33, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Nouv8MjaV0bU" + }, + "source": [ + "The accuracy is: `sum(accuracy)/len(accuracy)`" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "-7TxVQmnR7n4", + "outputId": "0742cdba-e3d3-449f-c540-531cad0d0135" + }, + "source": [ + "sum(accuracy)/len(accuracy)" + ], + "execution_count": 34, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "0.15224358974358973" + ] + }, + "metadata": {}, + "execution_count": 34 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 947 + }, + "id": "ukSbElZ8X9xb", + "outputId": "4eb1b0e9-b42d-4f3f-cee9-e0ce7610e00b" + }, + "source": [ + "from transformers import Trainer, TrainingArguments\n", + "from math import ceil\n", + "\n", + "batch_size=16\n", + "logging_steps=ceil(tokenized_datasets['train'].num_rows/batch_size)\n", + "num_train_epochs=5\n", + "\n", + "training_args = TrainingArguments(\n", + " \"test-clm\",\n", + " evaluation_strategy = \"epoch\",\n", + " learning_rate=2e-6,\n", + " weight_decay=0.01,\n", + " per_device_train_batch_size=batch_size,\n", + " per_device_eval_batch_size=batch_size,\n", + " num_train_epochs=num_train_epochs,\n", + " logging_steps=logging_steps,\n", + " save_steps=3000,\n", + ")\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=tokenized_datasets[\"train\"],\n", + " eval_dataset=tokenized_datasets[\"test\"],\n", + " data_collator=collator\n", + ")\n", + "\n", + "trainer.train()" + ], + "execution_count": 35, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "The following columns in the training set don't have a corresponding argument in 
`BertForMaskedLM.forward` and have been ignored: synonyms.\n", + "***** Running training *****\n", + " Num examples = 21952\n", + " Num Epochs = 5\n", + " Instantaneous batch size per device = 16\n", + " Total train batch size (w. parallel, distributed & accumulation) = 16\n", + " Gradient Accumulation steps = 1\n", + " Total optimization steps = 6860\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [6860/6860 07:32, Epoch 5/5]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
EpochTraining LossValidation Loss
16.8026005.952183
26.2752005.829013
36.0484005.781848
45.9206005.764913
55.8528005.764584

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "The following columns in the evaluation set don't have a corresponding argument in `BertForMaskedLM.forward` and have been ignored: synonyms.\n", + "***** Running Evaluation *****\n", + " Num examples = 2496\n", + " Batch size = 16\n", + "The following columns in the evaluation set don't have a corresponding argument in `BertForMaskedLM.forward` and have been ignored: synonyms.\n", + "***** Running Evaluation *****\n", + " Num examples = 2496\n", + " Batch size = 16\n", + "Saving model checkpoint to test-clm/checkpoint-3000\n", + "Configuration saved in test-clm/checkpoint-3000/config.json\n", + "Model weights saved in test-clm/checkpoint-3000/pytorch_model.bin\n", + "The following columns in the evaluation set don't have a corresponding argument in `BertForMaskedLM.forward` and have been ignored: synonyms.\n", + "***** Running Evaluation *****\n", + " Num examples = 2496\n", + " Batch size = 16\n", + "The following columns in the evaluation set don't have a corresponding argument in `BertForMaskedLM.forward` and have been ignored: synonyms.\n", + "***** Running Evaluation *****\n", + " Num examples = 2496\n", + " Batch size = 16\n", + "Saving model checkpoint to test-clm/checkpoint-6000\n", + "Configuration saved in test-clm/checkpoint-6000/config.json\n", + "Model weights saved in test-clm/checkpoint-6000/pytorch_model.bin\n", + "The following columns in the evaluation set don't have a corresponding argument in `BertForMaskedLM.forward` and have been ignored: synonyms.\n", + "***** Running Evaluation *****\n", + " Num examples = 2496\n", + " Batch size = 16\n", + "\n", + "\n", + "Training completed. Do not forget to share your model on huggingface.co/models =)\n", + "\n", + "\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "TrainOutput(global_step=6860, training_loss=6.179916550883746, metrics={'train_runtime': 452.7878, 'train_samples_per_second': 242.409, 'train_steps_per_second': 15.151, 'total_flos': 564245317440000.0, 'train_loss': 6.179916550883746, 'epoch': 5.0})" + ] + }, + "metadata": {}, + "execution_count": 35 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "K6tyTGL_f3X9" + }, + "source": [ + "### Take a look at the test" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "0Igs5Uai3XZU" + }, + "source": [ + "j=287\n", + "testdata = tokenized_datasets['test'].select([j]).remove_columns(\"label\")" + ], + "execution_count": 36, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "UIP8wsYi3mm1", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "3955f736-2e9d-4aa2-e88f-aaf96e17ea42" + }, + "source": [ + "for i in [j]:\n", + " print(tokenizer.decode(tokenized_datasets['test']['input_ids'][i]))\n", + " print(datasets['test']['synonyms'][i])" + ], + "execution_count": 37, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[CLS] approved is close in meaning to [MASK]. 
[SEP]\n", + "['approve', 'okay', 'sanctioned']\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "UXPSIUOu2S5E", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 106 + }, + "outputId": "12c3b186-eb46-4271-fa0c-6bd32dff0430" + }, + "source": [ + "temp = trainer.predict(testdata)" + ], + "execution_count": 38, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "The following columns in the test set don't have a corresponding argument in `BertForMaskedLM.forward` and have been ignored: synonyms.\n", + "***** Running Prediction *****\n", + " Num examples = 1\n", + " Batch size = 16\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/html": [ + "\n", + "

\n", + " \n", + " \n", + " [1/1 : < :]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {} + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "bT0kbTng_IUa", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "88582de8-12a8-458b-9bd8-f9921fa0c0e3" + }, + "source": [ + "mask_token_index = torch.where(torch.FloatTensor(testdata['input_ids'][0]) == tokenizer.mask_token_id)[0]\n", + "mask_token_index" + ], + "execution_count": 39, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([7])" + ] + }, + "metadata": {}, + "execution_count": 39 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "YMsl7xrquG7J", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "b7b1ac84-7953-4d7a-8dcf-012c8afb0459" + }, + "source": [ + "mask_token_logits = torch.FloatTensor(temp.predictions)[0, mask_token_index, :]\n", + "mask_token_logits" + ], + "execution_count": 40, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([[-0.3223, -0.4032, -0.2217, ..., -1.3434, -0.8945, -1.2912]])" + ] + }, + "metadata": {}, + "execution_count": 40 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "_oyQuKwfuNUY" + }, + "source": [ + "top_N_tokens = torch.topk(mask_token_logits, N, dim=1).indices[0].tolist()" + ], + "execution_count": 41, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "c9uhPYTAulOJ", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "d24c8acf-7cd8-4579-b644-39d24f83f601" + }, + "source": [ + "for i in top_N_tokens:\n", + " print(tokenizer.decode(i))" + ], + "execution_count": 42, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "approved\n", + "accepted\n", + "recognized\n", + "sanctioned\n", + "accredited\n" + ] + } + ] + } + ] +} \ No newline at end of file