Skip to content

Commit

Permalink
add code for training critic
Browse files Browse the repository at this point in the history
  • Loading branch information
henryhungle committed Aug 22, 2022
1 parent 1072ee5 commit 2412a3f
Show file tree
Hide file tree
Showing 23 changed files with 3,924 additions and 6 deletions.
38 changes: 36 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Authors:
* [x] [Generating Programs](###generate)
* [x] [Running Unit Tests](###runtests)
* [x] [Evaluating Programs](###evaluate)
* [x] [Training Critic](###trainingcritic)
* [ ] [Generating Programs with Critic Sampling](###criticsampling)
* [x] [Example Generated Programs](##exampleprogram)
* [x] [Citation](##cite)
Expand Down Expand Up @@ -64,6 +65,14 @@ The code requires some dependencies as specified in `requirements.txt`. Please f

`pip install -r requirements.txt`

Install the `transformers` library from the source code (the current source code is developed from the original [code](https://github.com/huggingface/transformers) of version 4.16.1):

```
cd transformers
pip install -e .
```


## Datasets <a name="datasets"></a>

For pretraining, apart from the [CodeSearchNet (CSN)](https://arxiv.org/abs/1909.09436), we use the [Python Github Code Dataset (GCPY)](https://huggingface.co/datasets/lvwerra/github-code).
Expand Down Expand Up @@ -114,7 +123,7 @@ We created `scripts/generate.sh` to generate programs on the APPS benchmark. You
| `num_seqs_per_iter` | Depending on the limit of GPU, we can generate multiple rounds, each with this number of output programs | 50 |
| `temp` | temperature for sampling generation | 0.6 ||

Other parameters are defined in the file `utils/generate_config.py`.
Other parameters are defined in the file `utils/generate_configs.py`.

Running the generation script will output programs, each of which is saved into a `json` file, including data fields `code` (list of output programs) and `prompt` (constructed input sequence to the LM model).

Expand Down Expand Up @@ -146,11 +155,36 @@ Compared to the original implementation from APPS, we adopt one trick which will
### Evaluating Programs <a name="evaluate"></a>
To compute the pass@k metrics, rather than using the APPS evaluation metrics, we follow the official implementation of the [HumanEval benchmark](https://github.com/openai/human-eval) (which better measures pass@k normalized by the number of possible k programs)


### Training Critic <a name="trainingcritic"></a>

We can train a critic model as a classifier that predicts the test outcomes of generated samples. For each training sample, we can follow the prior processes to generate programs and evaluate them with available unit tests. On average, we generate 20 programs per training sample (we provided some example generated programs in `data/APPS/train/`).

Once the programs are tested, we can used their test outcomes as annotations to train a critic model initialized from a LM pretrained on source code data (we used CodeT5-based in this case).

We created `scripts/train_critic.sh` and `scripts/train_critic_deepspeed.sh` to train a critic using generated programs. You can directly run this file by configuring the following parameters:

| **Parameters** | **Description** | **Example Values** |
|:--------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:---------------------------------------------------:|
| `batch-size-per-replica` | Number of training samples per GPU device | 8 |
| `grad-acc-steps` | Gradient accumulation steps | 1 |
| `epochs` | Number of training epochs | 10 |
| `lr` | Learning rate | 2e-5 |
| `save-freq` | Save model checkpoints after this number of training steps | 1000 |
| `log-freq` | Save model training losses after this number of training steps | 10 |
| `save_total_limit` | Total number of checkpoints to keep eventually (only the latest ones are kept) | 5 |
| `fp16` | Enable this to training model in 16-bit mode to reduce memory usage | N/A |
| `deepspeed` | If using deepspeed, set this parameter to the configuration file for deepspeed training | configs/deepspeed_configs.json |
| `db` | Enable this to train in debugging mode i.e. with small dummy data split and only 1 data worker | N/A |

Other parameters are defined in the file `utils/train_critic_configs.py`.

Running the script will train a critic model as a classifier that receives inputs as a problem description + a generated program and returns an output as one of 4 test outcomes: compile error, runtime error, failed tests, and passed tests. The model checkpoints are saved in a folder under `exps/`.

### Generating Programs with Critic Sampling <a name="criticsampling"></a>

We will release the implementation details of our critic sampling procedure.


## Example Generated Programs <a name="exampleprogram"></a>

<p align="center">
Expand Down
50 changes: 50 additions & 0 deletions configs/deepspeed_configs.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},

"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true,
"cpu_offload": true
},

"zero_allow_untested_optimizer": true,
"dump_state": false,

"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false,

"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},

"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
}
}
47 changes: 47 additions & 0 deletions configs/train_critic_configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#


import argparse

parser = argparse.ArgumentParser(description="Training a critic model for code generation")
parser.add_argument('--model', default='codet5-base', type=str, help='type of transformers model as model backbone')
parser.add_argument('--model_path', default=None, type=str, help='path to model backbone pretrained weights')
parser.add_argument('--save_dir', default=None, type=str, help='path to save trained critic model checkpoints')

# Dataloading
parser.add_argument('--train-path', default='data/APPS/train/', type=str, help='path to training data')
parser.add_argument('--sample-mode', default='uniform_sol', help='sampling output programs following a uniform distribution by program population')

# Training
parser.add_argument('--tuning_mode', default='critic', type=str, help='tuning mode for training LMs')
parser.add_argument('--epochs', default=10, type=int, help='total number of training epochs')
parser.add_argument('--lr', default=5e-5, type=float, help='training learning rate')
parser.add_argument('--batch-size-per-replica', default=4, type=int, help='batch size per GPU')
parser.add_argument('--grad-acc-steps', default=8, type=int, help='number of training steps before each gradient update')
parser.add_argument('--deepspeed', default = None, type=str, help='path to deepspeed configuration file; set None if not using deepspeed')
parser.add_argument('--fp16', default=True, action='store_true', help='set 16-bit training to reduce memory usage')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--db', default=False, action='store_true', help='set to turn on debug mode i.e. using dummy small data split and only 1 data worker')

# Logging
parser.add_argument('--log-freq', default=1, type=int, help='save training log after this number of training steps')
parser.add_argument('--save-freq', default=200, type=int, help='save model checkpoints after this number of training steps')
parser.add_argument('--save_total_limit', default=2, type=int, help='total of number checkpoints to keep; only keep the latest ones')

args = parser.parse_args()

if args.save_dir is None:
args.save_dir = '{}_{}_bs{}x{}_lr{}'.format(
args.model, args.tuning_mode,
args.batch_size_per_replica, args.grad_acc_steps, args.lr
)

if args.db:
args.save_dir = 'exps/test/{}'.format(args.save_dir)
else:
args.save_dir = 'exps/{}'.format(args.save_dir)
4 changes: 2 additions & 2 deletions configs/unit_test_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
parser.add_argument("--output_path", type=str, help="Path to output test results")
parser.add_argument("--code_path", type=str, help='Path to generated programs')

parser.add_argument("-i", "--index", default=-1, type=int)
parser.add_argument("-d", "--debug", action="store_true")
parser.add_argument("-i", "--index", default=-1, type=int, help='specific sample index to be tested against unit tests')
parser.add_argument("-d", "--debug", action="store_true", help='test in debugging mode with printout messages')
parser.add_argument('--max_tests', type=int, default=-1, help='Filter for test samples by maximum number of unit tests')
parser.add_argument('--example_tests', type=int, default=0, help='0: run hidden unit tests; 1: run example unit tests')

Expand Down
1 change: 1 addition & 0 deletions data/APPS/train/0001/gen_solutions.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"code": "q = int(input())\n\nfor i in range(q):\n\t(x, y, k) = list(map(int, input().split()))\n\n\tif max(x, y) > k:\n\t\tprint(-1)\n\telif x == y and k == x + 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 1 and y % 2 == 1 and k % 2 == 0:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 0 and y % 2 == 0 and k % 2 == 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif (x + y) % 2 == 0:\n\t\tprint(k)\n\telse:\n\t\tprint(k - 1)\n", "result": true, "error_type": "passed unit tests"}, {"code": "q = int(input())\nfor i in range(q):\n\tn, m, k = map(int, input().split())\n\tost = max(n, m) - min(n, m)\n\tplus = 0\n\tif ost % 2!= 0:\n\t\tplus = 1\n\t\tost -= 1\n\tmini = min(n, m) + ost + plus\n\t#print('mini:'+ str(mini))\n\tif k < mini:\n\t\tprint(-1)\n\telif (k - mini) % 2 == 0 or plus == 1:\n\t\tprint(k - plus)\n\telse:\n\t\tprint(k - plus - 2)\t\n", "result": true, "error_type": "passed unit tests"}, {"code": "q = int(input())\n\nfor i in range(q):\n\t(x, y, k) = list(map(int, input().split()))\n\n\tif max(x, y) > k:\n\t\tprint(-1)\n\telif x == y and k == x + 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 1 and y % 2 == 1 and k % 2 == 0:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 0 and y % 2 == 0 and k % 2 == 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif (x + y) % 2 == 0:\n\t\tprint(k)\n\telse:\n\t\tprint(k - 1)\n", "result": true, "error_type": "passed unit tests"}, {"code": "q = int(input())\n\nfor i in range(q):\n\t(x, y, k) = list(map(int, input().split()))\n\n\tif max(x, y) > k:\n\t\tprint(-1)\n\telif x == y and k == x + 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 1 and y % 2 == 1 and k % 2 == 0:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 0 and y % 2 == 0 and k % 2 == 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif (x + y) % 2 == 0:\n\t\tprint(k)\n\telse:\n\t\tprint(k - 1)\n", "result": true, "error_type": "passed unit tests"}, {"code": "q = int(input())\nfor i in range(q):\n\ta, b, k = map(int, input().split())\n\tif a < b:\n\t\ta, b = b, a\n\tif a % 2 == 0 and b % 2 == 0:\n\t\tprint(k - 1)\n\telif a % 2 == 1 and b % 2 == 1:\n\t\tprint(k - 2)\n\telse:\n\t\tprint(-1)\n", "result": false, "error_type": "failed unit tests"}, {"code": "q=int(input())\n\nfor i in range(q):\n\tn,m,k=list(map(int,input().split()))\n\n\tif max(n,m)>k:\n\t\tprint(-1)\n\n\telse:\n\t\tif n%2==0 and m%2==0:\n\t\t\tif k%2==0:\n\t\t\t\tprint(k)\n\t\t\telse:\n\t\t\t\tprint(k-2)\n\n\t\telif max(n%2,m%2)>k%2:\n\t\t\tprint(k-2)\n\n\t\telse:\n\t\t\tprint(k-1)\n", "result": true, "error_type": "passed unit tests"}, {"code": "q=int(input())\n\nfor i in range(q):\n\tn,m,k=list(map(int,input().split()))\n\n\tif n>k or m>k:\n\t\tprint(-1)\n\n\telse:\n\n\t\tif n%2==0 and m%2==0:\n\t\t\tif k%2==0:\n\t\t\t\tprint(k)\n\t\t\telse:\n\t\t\t\tprint(k-2)\n\n\t\telif (n%2==0 and m%2==1) or (n%2==1 and m%2==0):\n\t\t\tprint(k-1)\n\n\t\telif n%2==1 and m%2==1:\n\t\t\tif k%2==0:\n\t\t\t\tprint(k-2)\n\t\t\telse:\n\t\t\t\tprint(k)\n", "result": true, "error_type": "passed unit tests"}, {"code": "q = int(input())\nfor i in range(q):\n\tn, m, k = map(int, input().split())\n\tost = max(n, m) - min(n, m)\n\tplus = 0\n\tif ost % 2!= 0:\n\t\tplus = 1\n\t\tost -= 1\n\tmini = min(n, m) + ost + plus\n\t#print('mini:'+ str(mini))\n\tif k < mini:\n\t\tprint(-1)\n\telif (k - mini) % 2 == 0:\n\t\tprint(k - mini)\n\telse:\n\t\tprint(k - mini) # possible\n", "result": false, "error_type": "failed unit tests"}, {"code": "q=int(input())\n\nfor i in range(q):\n\tn,m,k=list(map(int,input().split()))\n\n\tif n>k or m>k:\n\t\tprint(-1)\n\n\telse:\n\n\t\tif n%2==0 and m%2==0:\n\t\t\tif k%2==0:\n\t\t\t\tprint(k)\n\t\t\telse:\n\t\t\t\tprint(k-2)\n\n\t\telif (n%2==0 and m%2==1) or (n%2==1 and m%2==0):\n\t\t\tprint(k-1)\n\n\t\telif n%2==1 and m%2==1:\n\t\t\tif k%2==0:\n\t\t\t\tprint(k-2)\n\t\t\telse:\n\t\t\t\tprint(k)\n", "result": true, "error_type": "passed unit tests"}, {"code": "\n\nq = int(input())\n\nfor _ in range(q):\n\tn, m, k = list(map(int, input().split()))\n\tif max([n, m]) > k:\n\t\tprint(-1)\n\telse:\n\t\tif 0 == (n + m) % 2:\n\t\t\tif k % 2 == max([n, m]) % 2:\n\t\t\t\tprint(k)\n\t\t\telse:\n\t\t\t\tprint(k - 2)\n\t\telse:\n\t\t\tprint(k - 1)\n", "result": true, "error_type": "passed unit tests"}, {"code": "q=int(input())\n\nfor i in range(q):\n\tn,m,k=list(map(int,input().split()))\n\n\tif max(n,m)>k:\n\t\tprint(-1)\n\n\telse:\n\t\tif n%2==0 and m%2==0:\n\t\t\tif k%2==0:\n\t\t\t\tprint(k)\n\t\t\telse:\n\t\t\t\tprint(k-2)\n\n\t\telif max(n%2,m%2)>k%2:\n\t\t\tprint(k-2)\n\n\t\telse:\n\t\t\tprint(k)\n", "result": false, "error_type": "failed unit tests"}, {"code": "q=int(input())\n\nfor i in range(q):\n\tn,m,k=list(map(int,input().split()))\n\n\tif max(n,m)>k:\n\t\tprint(-1)\n\n\telse:\n\t\tif n%2==0 and m%2==0:\n\t\t\tif k%2==0:\n\t\t\t\tprint(k)\n\t\t\telse:\n\t\t\t\tprint(k-2)\n\n\t\telif max(n%2,m%2)>k%2:\n\t\t\tprint(k-2)\n\n\t\telse:\n\t\t\tprint(k)\n", "result": false, "error_type": "failed unit tests"}, {"code": "\nimport sys\n#sys.stdin=open(\"data.txt\")\ninput=sys.stdin.readline\n\nfor _ in range(int(input())):\n\tn,m,k=list(map(int,input().split()))\n\tn,m=sorted([n,m])\n\tif n%2==1 and m%2==1:\n\t\tprint(-1)\n\t\tcontinue\n\tif n%2==0 and m%2==0:\n\t\tprint(k)\n\t\tcontinue\n\tif n%2==1 and m%2==1:\n\t\tprint(k-2)\n\t\tcontinue\n\tif n%2==0 and m%2==0:\n\t\tprint(k-1)\n\t\tcontinue\n\tif n%2==1 and m%2==0:\n\t\tprint(k-1)\n\t\tcontinue\n\tif n%2==0 and m%2==1:\n\t\tprint(k-2)\n\t\tcontinue\n\nprint(sum(n%2==m%2 for n in range(k+1)))\n", "result": false, "error_type": "failed unit tests"}, {"code": "q=int(input())\n\nfor i in range(q):\n\tn,m,k=list(map(int,input().split()))\n\n\tif n>k or m>k:\n\t\tprint(-1)\n\n\telse:\n\t\tif n%2==0 and m%2==0:\n\t\t\tif k%2==0:\n\t\t\t\tprint(k)\n\t\t\telse:\n\t\t\t\tprint(k-2)\n\n\t\telif (n%2==0 and m%2==1) or (n%2==1 and m%2==0):\n\t\t\tprint(k-1)\n\n\t\telif n%2==1 and m%2==1:\n\t\t\tif k%2==0:\n\t\t\t\tprint(k-2)\n\t\t\telse:\n\t\t\t\tprint(k)\n", "result": true, "error_type": "passed unit tests"}, {"code": "q = int(input())\n\nfor i in range(q):\n\t(x, y, k) = list(map(int, input().split()))\n\n\tif max(x, y) > k:\n\t\tprint(-1)\n\telif x == y and k == x + 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 1 and y % 2 == 1 and k % 2 == 0:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 0 and y % 2 == 0 and k % 2 == 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif (x + y) % 2 == 0:\n\t\tprint(k)\n\telse:\n\t\tprint(k - 1)\n", "result": true, "error_type": "passed unit tests"}, {"code": "q=int(input())\n\nfor i in range(q):\n\tn,m,k=list(map(int,input().split()))\n\n\tif max(n,m)>k:\n\t\tprint(-1)\n\n\telse:\n\t\tif n%2==0 and m%2==0:\n\t\t\tif k%2==0:\n\t\t\t\tprint(k)\n\t\t\telse:\n\t\t\t\tprint(k-2)\n\n\t\telif max(n%2,m%2)>k%2:\n\t\t\tprint(k-2)\n\n\t\telse:\n\t\t\tprint(k)\n", "result": false, "error_type": "failed unit tests"}, {"code": "q = int(input())\n\nfor i in range(q):\n\t(x, y, k) = list(map(int, input().split()))\n\n\tif max(x, y) > k:\n\t\tprint(-1)\n\telif x == y and k == x + 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 1 and y % 2 == 1 and k % 2 == 0:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 0 and y % 2 == 0 and k % 2 == 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif (x + y) % 2 == 0:\n\t\tprint(k)\n\telse:\n\t\tprint(k - 1)\n", "result": true, "error_type": "passed unit tests"}, {"code": "q=int(input())\n\nfor i in range(q):\n\tn,m,k=list(map(int,input().split()))\n\n\tif max(n,m)>k:\n\t\tprint(-1)\n\n\telse:\n\t\tif n%2==0 and m%2==0:\n\t\t\tif k%2==0:\n\t\t\t\tprint(k)\n\t\t\telse:\n\t\t\t\tprint(k-2)\n\n\t\telif max(n,m)%2==0 and min(n,m)%2==0:\n\t\t\tprint(k-1)\n\n\t\telse:\n\t\t\tprint(k-1)\n", "result": true, "error_type": "passed unit tests"}, {"code": "q = int(input())\n\nfor i in range(q):\n\t(x, y, k) = list(map(int, input().split()))\n\n\tif max(x, y) > k:\n\t\tprint(-1)\n\telif x == y and k == x + 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 1 and y % 2 == 1 and k % 2 == 0:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 0 and y % 2 == 0 and k % 2 == 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif (x + y) % 2 == 0:\n\t\tprint(k)\n\telse:\n\t\tprint(k - 1)\n", "result": true, "error_type": "passed unit tests"}, {"code": "q = int(input())\n\nfor i in range(q):\n\t(x, y, k) = list(map(int, input().split()))\n\n\tif max(x, y) > k:\n\t\tprint(-1)\n\telif x == y and k == x + 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 1 and y % 2 == 1 and k % 2 == 0:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 0 and y % 2 == 0 and k % 2 == 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif (x + y) % 2 == 0:\n\t\tprint(k)\n\telse:\n\t\tprint(k - 1)\n", "result": true, "error_type": "passed unit tests"}]
8 changes: 8 additions & 0 deletions data/APPS/train/0001/input_output.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"inputs": [
"3\n2 2 3\n4 3 7\n10 1 9\n"
],
"outputs": [
"1\n6\n-1\n"
]
}
1 change: 1 addition & 0 deletions data/APPS/train/0001/metadata.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"difficulty": "interview", "url": "https://codeforces.com/problemset/problem/1036/B"}
Loading

0 comments on commit 2412a3f

Please sign in to comment.