code for finetuning model with synthetic samples and their return estimates
henryhungle committed Nov 18, 2022
1 parent 38d21d5 commit 51db4ff
Showing 41 changed files with 410 additions and 47 deletions.
19 changes: 16 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -31,7 +31,7 @@ Authors:
* [x] [Training Critic](#training-critic)
* [x] [Generating Critic Scores](#generating-critic-scores)
* [x] [Finetuning with Ground-truth Programs](#finetuning-with-ground-truth-programs)
* [ ] [Finetuning with Generated Programs](#finetuning-with-generated-programs)
* [x] [Finetuning with Generated Programs](#finetuning-with-generated-programs)
* [ ] [Generating Programs with Critic Sampling](#generating-programs-with-critic-sampling)
* [x] [Example Generated Programs](#example-generated-programs)
* [x] [Citation](#citation)
@@ -208,7 +208,8 @@ We created `scripts/generate_critic_scores.sh` to generate critic scores for syn

Other parameters are defined in the file `utils/generate_configs.py`.

Running the generation script will output programs, each of which is saved into a `pkl` (pickle) file, including data fields `code` (list of programs), `prompt` (constructed input sequence to the critic model), `gt_error_type` (ground-truth test outcomes), `pred_error_type` (predicted test outcomes by critic), `error_hidden_states` (hidden states returned by critic).
Running the generation script will output predictions of the critic model.
For each data sample, the prediction is saved into a `pkl` (pickle) file, including data fields `code` (list of programs), `prompt` (constructed input sequence to the critic model), `gt_error_type` (ground-truth test outcomes), `pred_error_type` (predicted test outcomes by critic), `error_hidden_states` (hidden states returned by critic).
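As a rough illustration of how such a prediction file might be consumed, the sketch below writes and reloads one record. It assumes each pickle holds a single dict keyed by the field names listed above; the file name `0.pkl`, the helper names, and all dummy values are hypothetical and not taken from the repository.

```python
import pickle
import tempfile
from pathlib import Path

# Field names as listed in the README; everything else here is illustrative.
EXPECTED_FIELDS = (
    "code",
    "prompt",
    "gt_error_type",
    "pred_error_type",
    "error_hidden_states",
)


def load_critic_prediction(path):
    """Load one prediction file and verify the documented fields are present.

    Assumption: the pickle stores a dict keyed by the field names above.
    """
    with open(path, "rb") as f:
        record = pickle.load(f)
    missing = [key for key in EXPECTED_FIELDS if key not in record]
    if missing:
        raise KeyError(f"missing fields: {missing}")
    return record


def demo():
    """Round-trip a dummy record through a temporary pickle file."""
    dummy = {
        "code": ["print(1)"],            # placeholder program
        "prompt": "dummy prompt",        # placeholder critic input
        "gt_error_type": [0],            # placeholder label
        "pred_error_type": [0],          # placeholder prediction
        "error_hidden_states": [[0.0]],  # placeholder hidden states
    }
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "0.pkl"  # hypothetical file name
        with open(path, "wb") as f:
            pickle.dump(dummy, f)
        return load_critic_prediction(path)


record = demo()
```

Checking for the expected keys up front makes a malformed or partially written file fail early, before it reaches a training loop.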

### Finetuning with Ground-truth Programs

@@ -221,7 +222,19 @@ The model checkpoints are saved in a folder under `exps/`.

### Finetuning with Generated Programs

TBD
We created `scripts/train_actor_rl.sh` and `scripts/train_actor_rl_deepspeed.sh` to train pretrained LMs with synthetic generated programs.
We use the parameters as defined above in the [critic training process](#training-critic) with the following additional parameters:

| **Parameters** | **Description** | **Example Values** |
|:-----------------:|:--------------------------------------------------------------------------------------------------------:|:------------------------------:|
| `model_path` | Path to a finetuned model checkpoint e.g. from warm-up training | models/codet5_finetuned_codeRL |
| `relative_returns` | Enable this to consider a baseline to compute relative return estimates rather than absolute return estimates in the RL loss | N/A |

Other parameters are defined in the file `configs/train_configs.py`.
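To make the flag behaviour concrete, here is a minimal sketch that mirrors the three arguments this commit adds to the training config. The `build_parser` helper and the `'rl'` value for `--tuning_mode` are assumptions made for illustration; the actual scripts may pass different values.

```python
import argparse


def build_parser():
    """Rebuild only the three arguments added in this commit (illustrative subset)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--tuning_mode', default='critic', type=str,
                        help='tuning mode for training LMs')
    parser.add_argument('--relative_returns', default=False, action='store_true',
                        help='use relative returns against a baseline during RL')
    parser.add_argument('--clone_rl_head', default=False, action='store_true',
                        help='clone a separate linear layer for RL samples')
    return parser


# No flags given: both boolean options stay off and tuning_mode keeps its default.
default_args = build_parser().parse_args([])

# 'rl' is a hypothetical tuning mode value used only for this example.
rl_args = build_parser().parse_args(['--tuning_mode', 'rl', '--relative_returns'])
```

Because both options use `action='store_true'`, they take no value on the command line: passing the flag turns the option on, and omitting it leaves it `False`.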


Running the script will load a finetuned CodeT5-large model and continue to train it with both generated programs and ground-truth programs in alternating training steps.
The model checkpoints are saved in a folder under `exps/`.
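The difference between absolute and relative return estimates can be sketched as follows. This is not the repository's loss code: the reading of the `result` field (`-2` compile error, `-1` runtime error, `false` failed tests, `true` passed), the reward numbers, and the choice to subtract the baseline's return from each sample's return are all assumptions made for illustration.

```python
def reward(result):
    """Map a test outcome to a scalar return (illustrative values, assumed)."""
    if result is True:    # assumed: all unit tests passed
        return 1.0
    if result is False:   # assumed: ran but failed a unit test
        return -0.3
    if result == -1:      # assumed: runtime error
        return -0.6
    if result == -2:      # assumed: compile error
        return -1.0
    raise ValueError(f"unknown result: {result!r}")


def return_estimates(sample_results, baseline_result, relative_returns=False):
    """Return one estimate per generated sample.

    With relative_returns=True the baseline's return is subtracted from each
    sample's return; otherwise the absolute return is used unchanged.
    """
    baseline = reward(baseline_result) if relative_returns else 0.0
    return [reward(r) - baseline for r in sample_results]


# Outcomes resembling the `result` fields in gen_solutions.json and
# baseline_solutions.json for one problem: four samples, one baseline.
sample_results = [-1, -1, -1, -1]
baseline_result = -2

absolute = return_estimates(sample_results, baseline_result)
relative = return_estimates(sample_results, baseline_result, relative_returns=True)
```

Under these assumed values, a sample that fails at runtime still receives a positive relative return when the baseline did not even compile, which is the kind of signal a baseline is meant to provide.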

### Generating Programs with Critic Sampling

7 changes: 6 additions & 1 deletion configs/train_configs.py
@@ -17,8 +17,13 @@
parser.add_argument('--train-path', default='data/APPS/train/', type=str, help='path to training data')
parser.add_argument('--sample-mode', default='uniform_sol', help='sampling output programs following a uniform distribution by program population')

# Training
# Model
parser.add_argument('--tuning_mode', default='critic', type=str, help='tuning mode for training LMs')
parser.add_argument('--relative_returns', default=False, action='store_true', help='use relative returns against a baseline during RL')
parser.add_argument('--clone_rl_head', default=False, action='store_true', help='Optional: clone a separate linear layer for RL samples and initialize it from finetuned LM head')


# Training
parser.add_argument('--epochs', default=10, type=int, help='total number of training epochs')
parser.add_argument('--lr', default=5e-5, type=float, help='training learning rate')
parser.add_argument('--batch-size-per-replica', default=4, type=int, help='batch size per GPU')
1 change: 1 addition & 0 deletions data/APPS/train/0001/baseline_solutions.json
@@ -0,0 +1 @@
[{"code": "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\nstdin = sys.stdin\nstdout = sys.stdout\ndef code():\n\tq = int(input())\n\t\n\tfor i in range(q):\n\t\t(x, y, k) = list(map(int, input().split()))\n\t\n\t\tif max(x, y) > k:\n\t\t\tprint(-1)\n\t\telif x == y and k == x + 1:\n\t\t\tprint(k - 2)\n\t\t\tcontinue\n\t\telif x % 2 == 1 and y % 2 == 1 and k % 2 == 0:\n\t\t\tprint(k - 2)\n\t\t\tcontinue\n\t\telif x % 2 == 0 and y % 2 == 0 and k % 2 == 1:\n\t\t\tprint(k - 2)\n\t\t\tcontinue\n\t\telif (x + y) % 2 == 0:\n\t\t\tprint(k)\n\t\telse:\n\t\t\tprint(k - 1)\n\t\n", "result": -1}]
Binary file not shown.
1 change: 1 addition & 0 deletions data/APPS/train/0111/baseline_solutions.json
@@ -0,0 +1 @@
[{"code": "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\nstdin = sys.stdin\nstdout = sys.stdout\ndef code():\n\tn, m = map(int, input().split())\n\t\n\tdata = []\n\tfor i in range(n):\n\t\tdata.append(input())\n\t\n\tcount_1 = [0] * n\n\tfor i, d in enumerate(data):\n\t\tcount_1[i] += d.count('1')\n\t\n\tcurrent_max = count_1[0]\n\tcount_1 = [0] * n\n\tfor i in range(1, n):\n\t\tcount_1[i] += count_1[i-1]\n\t\tif count_1[i] > current_max:\n\t\t\tcurrent_max = count_1[i]\n\t\n\toutput = len(data) * (current_max - 1)\n\tfor i in range(0, current_max):\n\t\toutput += min((i + k - 1) * count_1[i] for i in range(0, current_max))\n\tprint(output)\n\t\n", "result": -1}]
Binary file not shown.
1 change: 1 addition & 0 deletions data/APPS/train/2287/baseline_solutions.json
@@ -0,0 +1 @@
[{"code": "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\nimport sys\nstdin = sys.stdin\nstdout = sys.stdout\ndef code():\n\tinput = sys.stdin.readline\n\tsys.setrecursionlimit(10**6)\n\t\n\tN = int(input())\n\tG = [[] for i in range(N)]\n\tfor i in range(N-1):\n\t\tA, B = map(int, input().split())\n\t\tA -= 1\n\t\tB -= 1\n\t\tG[A].append(B)\n\t\tG[B].append(A)\n\tK = int(input())\n\tL = [-10**9] * N\n\tR = [10**9] * N\n\tfor i in range(K):\n\t\tv, p = map(int, input().split())\n\t\tL[v] = p\n\t\tR[p] = v\n\t\n\t# \u89e3\u8aac\u3092\u53c2\u8003\u306b\u5b9f\u88c5\n\t# DFS\u3067N-K+1\u3092\u8fd4\u3059\n\tfrom collections import deque\n\tq = deque([(0, 0)])\n\tvisited = [-1] * N\n\tdepth = [0] * N\n\twhile q:\n\t\tv, c = q.popleft()\n\t\tc += 1\n\t\tp = L[v]\n\t\tvisited[v] = c\n\t\tfor w in G[v]:\n\t\t\tif visited[w] == -1:\n\t\t\t\tq.append((w, c))\n\t\t\tif depth[w] > depth[v] + 1:\n\t\t\t\tdepth[w] = depth[v] + 1\n\t\t\t\tq.append((w, c))\n\t\n\tdef DFS(v, p):\n\t\tfor w in G\n", "result": -2}]
1 change: 1 addition & 0 deletions data/APPS/train/2287/gen_solutions.json
@@ -0,0 +1 @@
[{"code": "import sys\ninput = sys.stdin.readline\nsys.setrecursionlimit(10**6)\n\nN = int(input())\nG = [[] for i in range(N)]\nfor i in range(N-1):\n\tA, B = [int(i) for i in input().split()]\n\tA -= 1\n\tB -= 1\n\tG[A].append(B)\n\tG[B].append(A)\nK = int(input())\nL = [0] * N\nfor i in range(K):\n\tV, P = [int(i) for i in input().split()]\n\tL[V-1] += P\n\n# check\nstack = [0]\ndepth = [-1] * N\ndepth[0] = 0\nwhile stack:\n\tv = stack[-1]\n\tfor i in G[v]:\n\t\tif depth[i] == -1:\n\t\t\tdepth[i] = depth[v] + 1\n\t\t\tstack.append(i)\nif len(set(L[1:])) == N:\n\tprint(\"Yes\")\nelse:\n\tprint(\"No\")\n", "result": -1, "error_type": "timeoutexception"}, {"code": "import sys\ninput = sys.stdin.readline\nsys.setrecursionlimit(10**6)\n\nN = int(input())\nG = [[] for _ in range(N+1)]\nfor _ in range(N-1):\n\tA, B = map(int, input().split())\n\tG[A].append(B)\n\tG[B].append(A)\nK = int(input())\nL = [-float(\"inf\")] * (N+1)\nR = [float(\"inf\")] * (N+1)\nfor _ in range(K):\n\tV, P = map(int, input().split())\n\tL[V] = P\n\tR[V] = P\n\n# \u89aa\u306eL[v]\u3068R[v]\u3092\u305d\u308c\u305e\u308csorted([R[v], L[v]])\u306b\u66f4\u65b0\u3057\u3066\u304b\u3089dfs\u3057\u3066\u307f\u308b\ndef dfs(v):\n\tfor w in G[v]:\n\t\tif w == R[v] or w == L[v]:\n\t\t\tcontinue\n\t\ttmp = w\n\t\tw = R[v] if L[v] >= w else L[v]\n\t\tL[w] = w\n\t\tdfs(w)\n\ndfs(1)\nfor i in range(1, N+1):\n\tprint(i, end=\"\")\n\tif i == R[1]:\n\t\tprint(1)\n\telse:\n\t\tprint(0)\n", "result": -1, "error_type": "typeerror"}, {"code": "import sys\ninput = sys.stdin.readline\nsys.setrecursionlimit(10**6)\n\nN = int(input())\nE = [[] for _ in range(N+1)]\nfor _ in range(N-1):\n\ta, b = list(map(int, input().split()))\n\tE[a].append(b)\n\tE[b].append(a)\nK = int(input())\nL = [-float(\"inf\")] * (N+1)\nR = [float(\"inf\")] * (N+1)\nfor _ in range(K):\n\tv, p = list(map(int, input().split()))\n\tL[v] = p\n\tR[v] = p\n\ntour = []\ndef dfs(v, p):\n\ttour.append(v)\n\tfor u in E[v]:\n\t\tif u!= p:\n\t\t\tdfs(u, 
v)\n\t\t\ttour.append(v)\nv = tour[0]\ndfs(v, v)\nl, r = L[v], R[v]\nif l == r:\n\tprint(\"No\")\n\treturn\nif l - r <= 2:\n\tprint(\"Yes\")\n\tfor u in E[v]:\n\t\tprint((u - p) % 2)\nelse:\n\tprint(\"No\")\n", "result": -1, "error_type": "indexerror"}, {"code": "import sys\ninput = sys.stdin.readline\nsys.setrecursionlimit(10**6)\n\nN = int(input())\nG = [[] for i in range(N)]\nfor i in range(N-1):\n\tA, B = map(int, input().split())\n\tA -= 1\n\tB -= 1\n\tG[A].append(B)\n\tG[B].append(A)\nK = int(input())\nL = [-10**9] * N\nR = [10**9] * N\nfor i in range(K):\n\tv, p = map(int, input().split())\n\tL[v] = p\n\tR[p] = v\n\nimport heapq\n\nQ = [(-L[i], i) for i in range(N)]\nheapq.heapify(Q)\nflag = [0] * N\nans = [0] * N\n\nwhile Q:\n\td, v = heapq.heappop(Q)\n\tif flag[v]:\n\t\tcontinue\n\tflag[v] = 1\n\tans[v] = d - 1\n\tfor w in G[v]:\n\t\tif flag[w]:\n\t\t\tcontinue\n\t\theapq.heappush(Q, (-(L[w] - L[v] + 1), w))\n\nprint('\\n'.join(map(str, ans)))\n", "result": -1, "error_type": "indexerror"}]
Binary file not shown.
1 change: 1 addition & 0 deletions data/APPS/train/2287/input_output.json
@@ -0,0 +1 @@
{"inputs": ["5\n1 2\n3 1\n4 3\n3 5\n2\n2 6\n5 7\n", "5\n1 2\n3 1\n4 3\n3 5\n3\n2 6\n4 3\n5 7\n", "4\n1 2\n2 3\n3 4\n1\n1 0\n"], "outputs": ["Yes\n5\n6\n6\n5\n7\n", "No\n", "Yes\n0\n-1\n-2\n-3\n"]}
1 change: 1 addition & 0 deletions data/APPS/train/2287/metadata.json
@@ -0,0 +1 @@
{"difficulty": "competition", "url": "https://atcoder.jp/contests/arc063/tasks/arc063_c"}
59 changes: 59 additions & 0 deletions data/APPS/train/2287/question.txt
@@ -0,0 +1,59 @@
We have a tree with N vertices. The vertices are numbered 1, 2, ..., N. The i-th (1 ≦ i ≦ N - 1) edge connects the two vertices A_i and B_i.
Takahashi wrote integers into K of the vertices. Specifically, for each 1 ≦ j ≦ K, he wrote the integer P_j into vertex V_j. The remaining vertices are left empty. After that, he got tired and fell asleep.
Then, Aoki appeared. He is trying to surprise Takahashi by writing integers into all empty vertices so that the following condition is satisfied:
- Condition: For any two vertices directly connected by an edge, the integers written into these vertices differ by exactly 1.
Determine if it is possible to write integers into all empty vertices so that the condition is satisfied. If the answer is positive, find one specific way to satisfy the condition.

-----Constraints-----
- 1 ≦ N ≦ 10^5
- 1 ≦ K ≦ N
- 1 ≦ A_i, B_i ≦ N (1 ≦ i ≦ N - 1)
- 1 ≦ V_j ≦ N (1 ≦ j ≦ K) (21:18, a mistake in this constraint was corrected)
- 0 ≦ P_j ≦ 10^5 (1 ≦ j ≦ K)
- The given graph is a tree.
- All v_j are distinct.

-----Input-----
The input is given from Standard Input in the following format:
N
A_1 B_1
A_2 B_2
:
A_{N-1} B_{N-1}
K
V_1 P_1
V_2 P_2
:
V_K P_K

-----Output-----
If it is possible to write integers into all empty vertices so that the condition is satisfied, print Yes. Otherwise, print No.
If it is possible to satisfy the condition, print N lines in addition. The v-th (1 ≦ v ≦ N) of these N lines should contain the integer that should be written into vertex v. If there are multiple ways to satisfy the condition, any of those is accepted.

-----Sample Input-----
5
1 2
3 1
4 3
3 5
2
2 6
5 7

-----Sample Output-----
Yes
5
6
6
5
7

The figure below shows the tree when Takahashi fell asleep. For each vertex, the integer written beside it represents the index of the vertex, and the integer written into the vertex is the integer written by Takahashi.
Aoki can, for example, satisfy the condition by writing integers into the remaining vertices as follows:
This corresponds to Sample Output 1. Note that other outputs that satisfy the condition will also be accepted, such as:
Yes
7
6
8
7
7
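
The acceptance condition in the problem statement above can be checked mechanically. The sketch below is a small validator, not a solver: it takes the tree, Takahashi's pre-written integers, and a candidate assignment, and tests whether every pair of adjacent vertices differs by exactly 1. The function name and argument layout are hypothetical.

```python
def check_assignment(n, edges, fixed, values):
    """Return True if `values` is a valid answer for the tree.

    n      -- number of vertices
    edges  -- list of (A_i, B_i) pairs, 1-indexed
    fixed  -- list of (V_j, P_j) pairs written by Takahashi
    values -- candidate integers for vertices 1..n, in order
    """
    if len(values) != n:
        return False
    # Pre-written integers must be preserved.
    if any(values[v - 1] != p for v, p in fixed):
        return False
    # Every edge must join integers that differ by exactly 1.
    return all(abs(values[a - 1] - values[b - 1]) == 1 for a, b in edges)


# Sample input from the problem statement.
edges = [(1, 2), (3, 1), (4, 3), (3, 5)]
fixed = [(2, 6), (5, 7)]

sample_ok = check_assignment(5, edges, fixed, [5, 6, 6, 5, 7])
alternative_ok = check_assignment(5, edges, fixed, [7, 6, 8, 7, 7])
broken = check_assignment(5, edges, fixed, [5, 6, 6, 5, 8])
```

Both outputs quoted in the statement pass this check, which matches the note that any assignment satisfying the condition is accepted.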
1 change: 1 addition & 0 deletions data/APPS/train/2287/solutions.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions data/APPS/train/2404/baseline_solutions.json
@@ -0,0 +1 @@
[{"code": "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\nstdin = sys.stdin\nstdout = sys.stdout\ndef code():\n\tclass Solution:\n\t\tdef findKthPositive(self, arr: List[int], k: int) -> int:\n\t\t\t\n\t\t\tcounter = set()\n\t\t\tfor i in arr:\n\t\t\t\tcounter.add(i)\n\t\t\t\n\t\t\tresult = []\n\t\t\t\n\t\t\tposition = 1\n\t\t\twhile k > 0:\n\t\t\t\tif position not in counter:\n\t\t\t\t\tk -= 1 \n\t\t\t\tposition += 1\n\t\t\t\t\n\t\t\treturn position - 1\n\t\t\t\t\n\t\t\t\t\n\t\n", "result": true}]