forked from salesforce/CodeRL
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1072ee5
commit 2412a3f
Showing
23 changed files
with
3,924 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
{ | ||
"fp16": { | ||
"enabled": "auto", | ||
"loss_scale": 0, | ||
"loss_scale_window": 1000, | ||
"initial_scale_power": 16, | ||
"hysteresis": 2, | ||
"min_loss_scale": 1 | ||
}, | ||
|
||
"zero_optimization": { | ||
"stage": 2, | ||
"allgather_partitions": true, | ||
"allgather_bucket_size": 2e8, | ||
"overlap_comm": true, | ||
"reduce_scatter": true, | ||
"reduce_bucket_size": 2e8, | ||
"contiguous_gradients": true, | ||
"cpu_offload": true | ||
}, | ||
|
||
"zero_allow_untested_optimizer": true, | ||
"dump_state": false, | ||
|
||
"gradient_accumulation_steps": "auto", | ||
"gradient_clipping": "auto", | ||
"steps_per_print": 2000, | ||
"train_batch_size": "auto", | ||
"train_micro_batch_size_per_gpu": "auto", | ||
"wall_clock_breakdown": false, | ||
|
||
"optimizer": { | ||
"type": "AdamW", | ||
"params": { | ||
"lr": "auto", | ||
"betas": "auto", | ||
"eps": "auto", | ||
"weight_decay": "auto" | ||
} | ||
}, | ||
|
||
"scheduler": { | ||
"type": "WarmupLR", | ||
"params": { | ||
"warmup_min_lr": "auto", | ||
"warmup_max_lr": "auto", | ||
"warmup_num_steps": "auto" | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# | ||
# Copyright (c) 2022, salesforce.com, inc. | ||
# All rights reserved. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | ||
# | ||
|
||
|
||
import argparse | ||
|
||
parser = argparse.ArgumentParser(description="Training a critic model for code generation") | ||
parser.add_argument('--model', default='codet5-base', type=str, help='type of transformers model as model backbone') | ||
parser.add_argument('--model_path', default=None, type=str, help='path to model backbone pretrained weights') | ||
parser.add_argument('--save_dir', default=None, type=str, help='path to save trained critic model checkpoints') | ||
|
||
# Dataloading | ||
parser.add_argument('--train-path', default='data/APPS/train/', type=str, help='path to training data') | ||
parser.add_argument('--sample-mode', default='uniform_sol', help='sampling output programs following a uniform distribution by program population') | ||
|
||
# Training | ||
parser.add_argument('--tuning_mode', default='critic', type=str, help='tuning mode for training LMs') | ||
parser.add_argument('--epochs', default=10, type=int, help='total number of training epochs') | ||
parser.add_argument('--lr', default=5e-5, type=float, help='training learning rate') | ||
parser.add_argument('--batch-size-per-replica', default=4, type=int, help='batch size per GPU') | ||
parser.add_argument('--grad-acc-steps', default=8, type=int, help='number of training steps before each gradient update') | ||
parser.add_argument('--deepspeed', default = None, type=str, help='path to deepspeed configuration file; set None if not using deepspeed') | ||
parser.add_argument('--fp16', default=True, action='store_true', help='set 16-bit training to reduce memory usage') | ||
parser.add_argument('--local_rank', default=-1, type=int) | ||
parser.add_argument('--db', default=False, action='store_true', help='set to turn on debug mode i.e. using dummy small data split and only 1 data worker') | ||
|
||
# Logging | ||
parser.add_argument('--log-freq', default=1, type=int, help='save training log after this number of training steps') | ||
parser.add_argument('--save-freq', default=200, type=int, help='save model checkpoints after this number of training steps') | ||
parser.add_argument('--save_total_limit', default=2, type=int, help='total of number checkpoints to keep; only keep the latest ones') | ||
|
||
args = parser.parse_args() | ||
|
||
if args.save_dir is None: | ||
args.save_dir = '{}_{}_bs{}x{}_lr{}'.format( | ||
args.model, args.tuning_mode, | ||
args.batch_size_per_replica, args.grad_acc_steps, args.lr | ||
) | ||
|
||
if args.db: | ||
args.save_dir = 'exps/test/{}'.format(args.save_dir) | ||
else: | ||
args.save_dir = 'exps/{}'.format(args.save_dir) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
[{"code": "q = int(input())\n\nfor i in range(q):\n\t(x, y, k) = list(map(int, input().split()))\n\n\tif max(x, y) > k:\n\t\tprint(-1)\n\telif x == y and k == x + 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 1 and y % 2 == 1 and k % 2 == 0:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 0 and y % 2 == 0 and k % 2 == 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif (x + y) % 2 == 0:\n\t\tprint(k)\n\telse:\n\t\tprint(k - 1)\n", "result": true, "error_type": "passed unit tests"}, {"code": "q = int(input())\nfor i in range(q):\n\tn, m, k = map(int, input().split())\n\tost = max(n, m) - min(n, m)\n\tplus = 0\n\tif ost % 2!= 0:\n\t\tplus = 1\n\t\tost -= 1\n\tmini = min(n, m) + ost + plus\n\t#print('mini:'+ str(mini))\n\tif k < mini:\n\t\tprint(-1)\n\telif (k - mini) % 2 == 0 or plus == 1:\n\t\tprint(k - plus)\n\telse:\n\t\tprint(k - plus - 2)\t\n", "result": true, "error_type": "passed unit tests"}, {"code": "q = int(input())\n\nfor i in range(q):\n\t(x, y, k) = list(map(int, input().split()))\n\n\tif max(x, y) > k:\n\t\tprint(-1)\n\telif x == y and k == x + 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 1 and y % 2 == 1 and k % 2 == 0:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 0 and y % 2 == 0 and k % 2 == 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif (x + y) % 2 == 0:\n\t\tprint(k)\n\telse:\n\t\tprint(k - 1)\n", "result": true, "error_type": "passed unit tests"}, {"code": "q = int(input())\n\nfor i in range(q):\n\t(x, y, k) = list(map(int, input().split()))\n\n\tif max(x, y) > k:\n\t\tprint(-1)\n\telif x == y and k == x + 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 1 and y % 2 == 1 and k % 2 == 0:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 0 and y % 2 == 0 and k % 2 == 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif (x + y) % 2 == 0:\n\t\tprint(k)\n\telse:\n\t\tprint(k - 1)\n", "result": true, "error_type": "passed unit tests"}, {"code": "q = int(input())\nfor i in range(q):\n\ta, b, k = map(int, input().split())\n\tif a < b:\n\t\ta, b = b, a\n\tif a % 2 == 0 and b % 2 == 0:\n\t\tprint(k - 1)\n\telif a % 2 == 1 and b % 2 == 1:\n\t\tprint(k - 2)\n\telse:\n\t\tprint(-1)\n", "result": false, "error_type": "failed unit tests"}, {"code": "q=int(input())\n\nfor i in range(q):\n\tn,m,k=list(map(int,input().split()))\n\n\tif max(n,m)>k:\n\t\tprint(-1)\n\n\telse:\n\t\tif n%2==0 and m%2==0:\n\t\t\tif k%2==0:\n\t\t\t\tprint(k)\n\t\t\telse:\n\t\t\t\tprint(k-2)\n\n\t\telif max(n%2,m%2)>k%2:\n\t\t\tprint(k-2)\n\n\t\telse:\n\t\t\tprint(k-1)\n", "result": true, "error_type": "passed unit tests"}, {"code": "q=int(input())\n\nfor i in range(q):\n\tn,m,k=list(map(int,input().split()))\n\n\tif n>k or m>k:\n\t\tprint(-1)\n\n\telse:\n\n\t\tif n%2==0 and m%2==0:\n\t\t\tif k%2==0:\n\t\t\t\tprint(k)\n\t\t\telse:\n\t\t\t\tprint(k-2)\n\n\t\telif (n%2==0 and m%2==1) or (n%2==1 and m%2==0):\n\t\t\tprint(k-1)\n\n\t\telif n%2==1 and m%2==1:\n\t\t\tif k%2==0:\n\t\t\t\tprint(k-2)\n\t\t\telse:\n\t\t\t\tprint(k)\n", "result": true, "error_type": "passed unit tests"}, {"code": "q = int(input())\nfor i in range(q):\n\tn, m, k = map(int, input().split())\n\tost = max(n, m) - min(n, m)\n\tplus = 0\n\tif ost % 2!= 0:\n\t\tplus = 1\n\t\tost -= 1\n\tmini = min(n, m) + ost + plus\n\t#print('mini:'+ str(mini))\n\tif k < mini:\n\t\tprint(-1)\n\telif (k - mini) % 2 == 0:\n\t\tprint(k - mini)\n\telse:\n\t\tprint(k - mini) # possible\n", "result": false, "error_type": "failed unit tests"}, {"code": "q=int(input())\n\nfor i in range(q):\n\tn,m,k=list(map(int,input().split()))\n\n\tif n>k or m>k:\n\t\tprint(-1)\n\n\telse:\n\n\t\tif n%2==0 and m%2==0:\n\t\t\tif k%2==0:\n\t\t\t\tprint(k)\n\t\t\telse:\n\t\t\t\tprint(k-2)\n\n\t\telif (n%2==0 and m%2==1) or (n%2==1 and m%2==0):\n\t\t\tprint(k-1)\n\n\t\telif n%2==1 and m%2==1:\n\t\t\tif k%2==0:\n\t\t\t\tprint(k-2)\n\t\t\telse:\n\t\t\t\tprint(k)\n", "result": true, "error_type": "passed unit tests"}, {"code": "\n\nq = int(input())\n\nfor _ in range(q):\n\tn, m, k = list(map(int, input().split()))\n\tif max([n, m]) > k:\n\t\tprint(-1)\n\telse:\n\t\tif 0 == (n + m) % 2:\n\t\t\tif k % 2 == max([n, m]) % 2:\n\t\t\t\tprint(k)\n\t\t\telse:\n\t\t\t\tprint(k - 2)\n\t\telse:\n\t\t\tprint(k - 1)\n", "result": true, "error_type": "passed unit tests"}, {"code": "q=int(input())\n\nfor i in range(q):\n\tn,m,k=list(map(int,input().split()))\n\n\tif max(n,m)>k:\n\t\tprint(-1)\n\n\telse:\n\t\tif n%2==0 and m%2==0:\n\t\t\tif k%2==0:\n\t\t\t\tprint(k)\n\t\t\telse:\n\t\t\t\tprint(k-2)\n\n\t\telif max(n%2,m%2)>k%2:\n\t\t\tprint(k-2)\n\n\t\telse:\n\t\t\tprint(k)\n", "result": false, "error_type": "failed unit tests"}, {"code": "q=int(input())\n\nfor i in range(q):\n\tn,m,k=list(map(int,input().split()))\n\n\tif max(n,m)>k:\n\t\tprint(-1)\n\n\telse:\n\t\tif n%2==0 and m%2==0:\n\t\t\tif k%2==0:\n\t\t\t\tprint(k)\n\t\t\telse:\n\t\t\t\tprint(k-2)\n\n\t\telif max(n%2,m%2)>k%2:\n\t\t\tprint(k-2)\n\n\t\telse:\n\t\t\tprint(k)\n", "result": false, "error_type": "failed unit tests"}, {"code": "\nimport sys\n#sys.stdin=open(\"data.txt\")\ninput=sys.stdin.readline\n\nfor _ in range(int(input())):\n\tn,m,k=list(map(int,input().split()))\n\tn,m=sorted([n,m])\n\tif n%2==1 and m%2==1:\n\t\tprint(-1)\n\t\tcontinue\n\tif n%2==0 and m%2==0:\n\t\tprint(k)\n\t\tcontinue\n\tif n%2==1 and m%2==1:\n\t\tprint(k-2)\n\t\tcontinue\n\tif n%2==0 and m%2==0:\n\t\tprint(k-1)\n\t\tcontinue\n\tif n%2==1 and m%2==0:\n\t\tprint(k-1)\n\t\tcontinue\n\tif n%2==0 and m%2==1:\n\t\tprint(k-2)\n\t\tcontinue\n\nprint(sum(n%2==m%2 for n in range(k+1)))\n", "result": false, "error_type": "failed unit tests"}, {"code": "q=int(input())\n\nfor i in range(q):\n\tn,m,k=list(map(int,input().split()))\n\n\tif n>k or m>k:\n\t\tprint(-1)\n\n\telse:\n\t\tif n%2==0 and m%2==0:\n\t\t\tif k%2==0:\n\t\t\t\tprint(k)\n\t\t\telse:\n\t\t\t\tprint(k-2)\n\n\t\telif (n%2==0 and m%2==1) or (n%2==1 and m%2==0):\n\t\t\tprint(k-1)\n\n\t\telif n%2==1 and m%2==1:\n\t\t\tif k%2==0:\n\t\t\t\tprint(k-2)\n\t\t\telse:\n\t\t\t\tprint(k)\n", "result": true, "error_type": "passed unit tests"}, {"code": "q = int(input())\n\nfor i in range(q):\n\t(x, y, k) = list(map(int, input().split()))\n\n\tif max(x, y) > k:\n\t\tprint(-1)\n\telif x == y and k == x + 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 1 and y % 2 == 1 and k % 2 == 0:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 0 and y % 2 == 0 and k % 2 == 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif (x + y) % 2 == 0:\n\t\tprint(k)\n\telse:\n\t\tprint(k - 1)\n", "result": true, "error_type": "passed unit tests"}, {"code": "q=int(input())\n\nfor i in range(q):\n\tn,m,k=list(map(int,input().split()))\n\n\tif max(n,m)>k:\n\t\tprint(-1)\n\n\telse:\n\t\tif n%2==0 and m%2==0:\n\t\t\tif k%2==0:\n\t\t\t\tprint(k)\n\t\t\telse:\n\t\t\t\tprint(k-2)\n\n\t\telif max(n%2,m%2)>k%2:\n\t\t\tprint(k-2)\n\n\t\telse:\n\t\t\tprint(k)\n", "result": false, "error_type": "failed unit tests"}, {"code": "q = int(input())\n\nfor i in range(q):\n\t(x, y, k) = list(map(int, input().split()))\n\n\tif max(x, y) > k:\n\t\tprint(-1)\n\telif x == y and k == x + 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 1 and y % 2 == 1 and k % 2 == 0:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 0 and y % 2 == 0 and k % 2 == 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif (x + y) % 2 == 0:\n\t\tprint(k)\n\telse:\n\t\tprint(k - 1)\n", "result": true, "error_type": "passed unit tests"}, {"code": "q=int(input())\n\nfor i in range(q):\n\tn,m,k=list(map(int,input().split()))\n\n\tif max(n,m)>k:\n\t\tprint(-1)\n\n\telse:\n\t\tif n%2==0 and m%2==0:\n\t\t\tif k%2==0:\n\t\t\t\tprint(k)\n\t\t\telse:\n\t\t\t\tprint(k-2)\n\n\t\telif max(n,m)%2==0 and min(n,m)%2==0:\n\t\t\tprint(k-1)\n\n\t\telse:\n\t\t\tprint(k-1)\n", "result": true, "error_type": "passed unit tests"}, {"code": "q = int(input())\n\nfor i in range(q):\n\t(x, y, k) = list(map(int, input().split()))\n\n\tif max(x, y) > k:\n\t\tprint(-1)\n\telif x == y and k == x + 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 1 and y % 2 == 1 and k % 2 == 0:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 0 and y % 2 == 0 and k % 2 == 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif (x + y) % 2 == 0:\n\t\tprint(k)\n\telse:\n\t\tprint(k - 1)\n", "result": true, "error_type": "passed unit tests"}, {"code": "q = int(input())\n\nfor i in range(q):\n\t(x, y, k) = list(map(int, input().split()))\n\n\tif max(x, y) > k:\n\t\tprint(-1)\n\telif x == y and k == x + 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 1 and y % 2 == 1 and k % 2 == 0:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif x % 2 == 0 and y % 2 == 0 and k % 2 == 1:\n\t\tprint(k - 2)\n\t\tcontinue\n\telif (x + y) % 2 == 0:\n\t\tprint(k)\n\telse:\n\t\tprint(k - 1)\n", "result": true, "error_type": "passed unit tests"}] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
{ | ||
"inputs": [ | ||
"3\n2 2 3\n4 3 7\n10 1 9\n" | ||
], | ||
"outputs": [ | ||
"1\n6\n-1\n" | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"difficulty": "interview", "url": "https://codeforces.com/problemset/problem/1036/B"} |
Oops, something went wrong.