Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions .github/workflows/parity.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: Parity check

on:
push:
pull_request:

jobs:
parity:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"

- name: Install deps (CPU only)
run: |
python -m pip install --upgrade pip
pip install --index-url https://download.pytorch.org/whl/cpu torch
pip install numpy "jax[cpu]" dm-haiku
pip install git+https://github.com/google-deepmind/tracr.git

- name: Compile & export tracr
run: python scripts/compile_export.py

- name: Verify parity (fail if mismatch)
run: python scripts/parity_check.py
Binary file modified __pycache__/tracr_transformer_pt.cpython-313.pyc
Binary file not shown.
1 change: 1 addition & 0 deletions artifacts/input_tokens.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
["BOS", 1, 0, 1, 1, 0]
6 changes: 6 additions & 0 deletions artifacts/token_to_id.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"BOS": 2,
"0": 0,
"1": 1,
"PAD": 3
}
Binary file added artifacts/tracr_majority_params.npz
Binary file not shown.
Binary file added artifacts/tracr_output.npy
Binary file not shown.
40 changes: 40 additions & 0 deletions artifacts/tracr_param_keys.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
[
"pos_embed__embeddings",
"token_embed__embeddings",
"transformer__layer_0__attn__key__b",
"transformer__layer_0__attn__key__w",
"transformer__layer_0__attn__linear__b",
"transformer__layer_0__attn__linear__w",
"transformer__layer_0__attn__query__b",
"transformer__layer_0__attn__query__w",
"transformer__layer_0__attn__value__b",
"transformer__layer_0__attn__value__w",
"transformer__layer_0__mlp__linear_1__b",
"transformer__layer_0__mlp__linear_1__w",
"transformer__layer_0__mlp__linear_2__b",
"transformer__layer_0__mlp__linear_2__w",
"transformer__layer_1__attn__key__b",
"transformer__layer_1__attn__key__w",
"transformer__layer_1__attn__linear__b",
"transformer__layer_1__attn__linear__w",
"transformer__layer_1__attn__query__b",
"transformer__layer_1__attn__query__w",
"transformer__layer_1__attn__value__b",
"transformer__layer_1__attn__value__w",
"transformer__layer_1__mlp__linear_1__b",
"transformer__layer_1__mlp__linear_1__w",
"transformer__layer_1__mlp__linear_2__b",
"transformer__layer_1__mlp__linear_2__w",
"transformer__layer_2__attn__key__b",
"transformer__layer_2__attn__key__w",
"transformer__layer_2__attn__linear__b",
"transformer__layer_2__attn__linear__w",
"transformer__layer_2__attn__query__b",
"transformer__layer_2__attn__query__w",
"transformer__layer_2__attn__value__b",
"transformer__layer_2__attn__value__w",
"transformer__layer_2__mlp__linear_1__b",
"transformer__layer_2__mlp__linear_1__w",
"transformer__layer_2__mlp__linear_2__b",
"transformer__layer_2__mlp__linear_2__w"
]
60 changes: 0 additions & 60 deletions export_tracr_params.py

This file was deleted.

366 changes: 0 additions & 366 deletions graph.gv

This file was deleted.

Binary file removed graph.gv.pdf
Binary file not shown.
1 change: 1 addition & 0 deletions input_tokens.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
["BOS", 1, 0, 1, 1, 0]
70 changes: 0 additions & 70 deletions load_and_visualize_with_torchlens.py

This file was deleted.

183 changes: 0 additions & 183 deletions my_majority_program.py

This file was deleted.

Loading