Update docs, bump stim to 1.10 (oscarhiggott#55)
* decode_batch test

* Update benchmarks using decode_batch

* Bump stim C++ dependency to v1.10.0. Fixes oscarhiggott#54

* Update docs for drawing

* Update toric code example notebook

* Update stim C++ dependency to latest commit

* Add b8 test

* Fix cli test

* Add surface code b8 test

* Binary read/write

* stim 1.10.0

* include obs

* update workflow

Co-authored-by: Oscar Higgott <[email protected]>
oscarhiggott authored Jan 18, 2023
1 parent 97e0d0a commit e4a451c
Showing 18 changed files with 301 additions and 130 deletions.
36 changes: 18 additions & 18 deletions .github/workflows/ci.yml
@@ -246,21 +246,21 @@ jobs:
run: pytest tests --cov=./src/pymatching --cov-report=xml
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
-  upload_all:
-    name: Upload to pypi
-    needs: [build_wheels, build_sdist]
-    runs-on: ubuntu-latest
-    if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags/v')
-    steps:
-      - uses: actions/setup-python@v4
-        with:
-          python-version: "3.9"
-
-      - uses: actions/download-artifact@v3
-        with:
-          name: artifact
-          path: dist
-
-      - uses: pypa/[email protected]
-        with:
-          password: ${{ secrets.pypi_password }}
+#  upload_all:
+#    name: Upload to pypi
+#    needs: [build_wheels, build_sdist]
+#    runs-on: ubuntu-latest
+#    if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags/v')
+#    steps:
+#      - uses: actions/setup-python@v4
+#        with:
+#          python-version: "3.9"
+#
+#      - uses: actions/download-artifact@v3
+#        with:
+#          name: artifact
+#          path: dist
+#
+#      - uses: pypa/[email protected]
+#        with:
+#          password: ${{ secrets.pypi_password }}
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -24,7 +24,7 @@ FetchContent_MakeAvailable(googletest)

FetchContent_Declare(stim
    GIT_REPOSITORY https://github.com/quantumlib/stim.git
-    GIT_TAG v1.9.0)
+    GIT_TAG v1.10.0)
FetchContent_GetProperties(stim)
if (NOT stim_POPULATED)
FetchContent_Populate(stim)
13 changes: 5 additions & 8 deletions README.md
@@ -41,9 +41,8 @@ The new version is also exact - unlike previous versions of PyMatching, no appro

Our new implementation is **over 100x faster** than previous versions of PyMatching, and is
**over 100,000x faster** than NetworkX (benchmarked with surface code circuits). At 0.1% circuit-noise, PyMatching can
-decode both X and Z basis measurements of surface code circuits up to distance 13 in under 1 microsecond per round
-of syndrome extraction on a single core (or up to distance 19 if only X-basis measurements are processed - however
-both X and Z basis measurements must be decoded at scale). Furthermore, the runtime is roughly linear in the number
+decode both X and Z basis measurements of surface code circuits up to distance 17 in under 1 microsecond per round
+of syndrome extraction on a single core. Furthermore, the runtime is roughly linear in the number
of nodes in the graph.

The plot below compares the performance of PyMatching v2 with the previous
@@ -127,12 +126,10 @@ Now we can decode! We compare PyMatching's predictions of the logical observable
with stim, in order to count the number of mistakes and estimate the logical error rate:

```python
-num_errors = 0
-for i in range(syndrome.shape[0]):
-    predicted_observables = matching.decode(syndrome[i, :])
-    num_errors += not np.array_equal(actual_observables[i, :], predicted_observables)
+predicted_observables = matching.decode_batch(syndrome)
+num_errors = np.sum(np.any(predicted_observables != actual_observables, axis=1))

-print(num_errors) # prints 8
+print(num_errors) # prints 5
```
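
As an aside (not part of the commit), the mistake count from the updated snippet can be turned into a logical error rate estimate; this sketch assumes, as in the README example, that `syndrome` holds one sampled shot per row:

```python
# Estimate the logical error rate from the values computed above.
num_shots = syndrome.shape[0]                # one sampled shot per row
logical_error_rate = num_errors / num_shots  # fraction of shots with a mispredicted observable
print(logical_error_rate)
```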

### Loading from a parity check matrix
45 changes: 40 additions & 5 deletions benchmarks/surface_codes/README.md
@@ -24,12 +24,38 @@ for a given error rate `p` and code distance `distance` in `{5, 7, 9, 13, 17, 23
From a stim circuit `circuit`, the corresponding `stim.DetectorErrorModel` used to configure the decoder can be
generated using `circuit.detector_error_model(decompose_errors=True)`.

-Then, using stim to generate some samples in b8 format (in this case with appended observables), the time per shot
-in microseconds for pymatching 2 was measured (on an M1 Max processor) using the pymatching command line tool, where here
-`$samples_fn` is the filename of the b8 samples file and `$dem_fn` is the filename of the detector error model:
-
-```shell
-pymatching count_mistakes --in $samples_fn --in_format b8 --dem $dem_fn --in_includes_appended_observables --time 2>&1 >/dev/null | sed -n -E 's/Decoding time per shot: (.+)us/\1/p'
+Then, using stim to generate some samples, the time per shot
+in microseconds for pymatching 2 was measured (on an M1 Max processor) by running `pymatching.Matching.decode_batch` on
+at least 10000 shots. For example, the number of microseconds per shot can be measured using the following function:
+```python
+import time
+import stim
+import pymatching
+
+
+def time_surface_code_circuit(distance: int, p: float, num_shots: int = 10000) -> float:
+    circuit = stim.Circuit.generated(
+        "surface_code:rotated_memory_x",
+        distance=distance,
+        rounds=distance,
+        after_clifford_depolarization=p,
+        before_round_data_depolarization=p,
+        before_measure_flip_probability=p,
+        after_reset_flip_probability=p
+    )
+    dem = circuit.detector_error_model(decompose_errors=True)
+    matching = pymatching.Matching.from_detector_error_model(dem)
+    sampler = circuit.compile_detector_sampler()
+    shots, actual_observables = sampler.sample(shots=num_shots, separate_observables=True)
+    # Decode one shot first to ensure internal C++ representation of the matching graph is fully cached
+    matching.decode_batch(shots[0:1, :])
+    # Now time decoding the batch
+    t0 = time.time()
+    matching.decode_batch(shots)
+    t1 = time.time()
+    microseconds_per_shot = 1e6*(t1-t0)/num_shots
+    return microseconds_per_shot
```
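
As an illustrative aside (not part of the commit), the benchmark function above could be called like this; the distance and error rate are example values and the exact timing will vary by machine:

```python
# Hypothetical call to the time_surface_code_circuit function defined above:
# time a distance-13 surface-code circuit at p = 0.1% circuit noise.
us_per_shot = time_surface_code_circuit(distance=13, p=0.001)
print(f"{us_per_shot:.2f} microseconds per shot")
```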

In the figure in each subdirectory,
@@ -46,3 +72,12 @@ must be done by the decoder within the wider context of a fault-tolerant quantum
where the basis of the logical measurement is not always known apriori (both X and Z logical operators must be
preserved) and the X and Z matching graphs can become connected (such as when implementing a logical S gate in the
surface code by [braiding twist defects](https://arxiv.org/abs/1609.04673)).

+If you've looked at the benchmarks in this repository before, you may have noticed that there have been two previous
+versions of the data. In the first version,
+only the X basis was decoded, which did not fully represent the work required to decode a surface code at scale,
+as described above. In the second version, both bases were decoded (and since the problems became 2x bigger,
+the time per round also doubled). However, for both the first and second versions, the timing data was collected by
+decoding shot data from file using the pymatching command line interface. At low p (e.g. around 0.1%), it turned out
+that almost half the time was spent reading the shot data from file. So in the current version the shot data is
+decoded in a batch from memory (see above), with both bases still decoded.
Binary file not shown.
@@ -1,10 +1,10 @@
d,p,microseconds
-5, 0.001,0.7137
-7, 0.001,1.8224
-9, 0.001,3.9624
-13, 0.001,11.7847
-17, 0.001,25.6879
-23, 0.001,68.0293
-29, 0.001,150.337
-39, 0.001,436.855
-50, 0.001,1060.09
+5, 0.001,0.388141
+7, 0.001,1.05477
+9, 0.001,2.31456
+13, 0.001,6.96616
+17, 0.001,16.0062
+23, 0.001,45.1341
+29, 0.001,102.495
+39, 0.001,310.649
+50, 0.001,778.247
Binary file not shown.
@@ -1,10 +1,10 @@
d,p,microseconds
-5, 0.005,3.1209
-7, 0.005,7.2927
-9, 0.005,16.3826
-13, 0.005,56.8005
-17, 0.005,146.022
-23, 0.005,391.382
-29, 0.005,878.586
-39, 0.005,2687.91
-50, 0.005,7183.41
+5, 0.005,1.83938
+7, 0.005,6.13618
+9, 0.005,14.5596
+13, 0.005,50.9597
+17, 0.005,125.17
+23, 0.005,338.516
+29, 0.005,718.097
+39, 0.005,2349.41
+50, 0.005,6834.5
Binary file not shown.
@@ -1,10 +1,10 @@
d,p,microseconds
-5, 0.008,4.8573
-7, 0.008,12.9679
-9, 0.008,33.1905
-13, 0.008,128.431
-17, 0.008,337.928
-23, 0.008,947.958
-29, 0.008,2268.38
-39, 0.008,7801.89
-50, 0.008,24642.2
+5, 0.008,6.84748
+7, 0.008,13.299
+9, 0.008,30.117
+13, 0.008,117.831
+17, 0.008,317.935
+23, 0.008,935.859
+29, 0.008,2185.42
+39, 0.008,7447.51
+50, 0.008,23116.7
Binary file not shown.
@@ -1,10 +1,6 @@
d,p,microseconds
-5, 0.01,5.4227
-7, 0.01,17.2281
-9, 0.01,44.5062
-13, 0.01,185.097
-17, 0.01,504.902
-23, 0.01,1483.5
-29, 0.01,3549.17
-39, 0.01,12820.4
-50, 0.01,40346.1
+5, 0.01,9.1722
+7, 0.01,17.3477
+9, 0.01,42.1067
+13, 0.01,171.797
+17, 0.01,469.216
+23, 0.01,1460.07
Binary file added data/three_errors.b8
Binary file not shown.
218 changes: 169 additions & 49 deletions docs/toric-code-example.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion src/pymatching/matching.py
@@ -570,10 +570,11 @@ def decode_to_matched_dets_dict(self,

    def draw(self) -> None:
        """Draw the matching graph using matplotlib
-        Draws the matching graph as a matplotlib graph. Stabiliser nodes are
+        Draws the matching graph as a matplotlib graph. Detector nodes are
        filled grey and boundary nodes are filled white. The line thickness of each
        edge is determined from its weight (with min and max thicknesses of 0.2 pts
        and 2 pts respectively).
+        Each node is labelled with its id/index, and each edge is labelled with its `fault_ids`.
        Note that you may need to call `plt.figure()` before and `plt.show()` after calling
        this function.
        """
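
As a usage note (not part of the commit), here is a minimal sketch of calling `draw()` in the way the updated docstring describes; the three-edge graph is hypothetical and exists only to give the method something to plot:

```python
import matplotlib.pyplot as plt
import pymatching

# Hypothetical small matching graph, purely for illustration.
matching = pymatching.Matching()
matching.add_edge(0, 1, fault_ids={0})
matching.add_edge(1, 2, fault_ids={1})
matching.add_boundary_edge(2, fault_ids={2})

plt.figure()     # create a figure first, as the docstring recommends
matching.draw()  # detector nodes filled grey, boundary nodes white, edges labelled with fault_ids
plt.show()       # then display it
```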
10 changes: 5 additions & 5 deletions src/pymatching/sparse_blossom/driver/namespaced_main.cc
@@ -40,8 +40,8 @@ int main_predict(int argc, const char **argv) {
argc,
argv);

FILE *shots_in = stim::find_open_file_argument("--in", stdin, "r", argc, argv);
FILE *predictions_out = stim::find_open_file_argument("--out", stdout, "w", argc, argv);
FILE *shots_in = stim::find_open_file_argument("--in", stdin, "rb", argc, argv);
FILE *predictions_out = stim::find_open_file_argument("--out", stdout, "wb", argc, argv);
FILE *dem_file = stim::find_open_file_argument("--dem", nullptr, "r", argc, argv);
stim::FileFormatData shots_in_format =
stim::find_enum_argument("--in_format", "b8", stim::format_name_to_enum_map, argc, argv);
@@ -100,9 +100,9 @@ int main_count_mistakes(int argc, const char **argv) {
argc,
argv);

FILE *shots_in = stim::find_open_file_argument("--in", stdin, "r", argc, argv);
FILE *obs_in = stim::find_open_file_argument("--obs_in", stdin, "r", argc, argv);
FILE *stats_out = stim::find_open_file_argument("--out", stdout, "w", argc, argv);
FILE *shots_in = stim::find_open_file_argument("--in", stdin, "rb", argc, argv);
FILE *obs_in = stim::find_open_file_argument("--obs_in", stdin, "rb", argc, argv);
FILE *stats_out = stim::find_open_file_argument("--out", stdout, "wb", argc, argv);
FILE *dem_file = stim::find_open_file_argument("--dem", nullptr, "r", argc, argv);
stim::FileFormatData shots_in_format =
stim::find_enum_argument("--in_format", "01", stim::format_name_to_enum_map, argc, argv);
34 changes: 28 additions & 6 deletions tests/cli_test.py
@@ -10,16 +10,16 @@


@contextmanager
-def three_errors_data():
+def three_errors_data(in_fmt: str):
    out_fn = os.path.join(DATA_DIR, "three_errors_predictions.dets")
    if os.path.isfile(out_fn):
        os.remove(out_fn)
    assert not os.path.exists(out_fn)
    args = [
        "predict",
        "--dem", os.path.join(DATA_DIR, "three_errors.dem"),
-        "--in", os.path.join(DATA_DIR, "three_errors.dets"),
-        "--in_format", "dets",
+        "--in", os.path.join(DATA_DIR, "three_errors." + in_fmt),
+        "--in_format", in_fmt,
        "--out", out_fn,
        "--out_format", "dets",
    ]
@@ -35,17 +35,39 @@ def three_errors_data():


def test_cli():
-    with three_errors_data() as args:
+    with three_errors_data("dets") as args:
        cli(command_line_args=args)


def test_protected_cli():
-    with three_errors_data() as args:
+    with three_errors_data("dets") as args:
        pymatching._cpp_pymatching.main(command_line_args=args)


+def test_protected_cli_b8_in():
+    with three_errors_data("b8") as args:
+        pymatching._cpp_pymatching.main(command_line_args=args)


def test_cli_argv():
    from unittest.mock import patch
-    with three_errors_data() as args:
+    with three_errors_data("dets") as args:
        with patch.object(sys, 'argv', ["cli"] + args):
            cli_argv()


+def test_load_surface_code_b8_cli():
+    dem_path = os.path.join(DATA_DIR, "surface_code_rotated_memory_x_13_0.01.dem")
+    dets_b8_in_path = os.path.join(DATA_DIR, "surface_code_rotated_memory_x_13_0.01_1000_shots.b8")
+    out_fn = os.path.join(DATA_DIR, "surface_code_rotated_memory_x_13_0.01_1000_shots_temp_predictions.b8")
+
+    pymatching._cpp_pymatching.main(command_line_args=[
+        "predict",
+        "--dem", dem_path,
+        "--in", dets_b8_in_path,
+        "--in_format", "b8",
+        "--out", out_fn,
+        "--out_format", "b8",
+        "--in_includes_appended_observables"
+    ])
+    os.remove(out_fn)
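
As a hedged aside (not part of the commit), the b8 detection-event file referenced by this test could plausibly be regenerated with stim as sketched below; the noise parameters are assumptions inferred from the `13_0.01_1000_shots` file naming, not taken from the repository:

```python
import stim

# Assumed recipe: 1000 shots of a distance-13 rotated memory-X surface code circuit
# at p = 1% circuit noise, written in b8 format with the observables appended to
# each shot (which is why the test passes --in_includes_appended_observables).
circuit = stim.Circuit.generated(
    "surface_code:rotated_memory_x",
    distance=13,
    rounds=13,
    after_clifford_depolarization=0.01,
    before_round_data_depolarization=0.01,
    before_measure_flip_probability=0.01,
    after_reset_flip_probability=0.01,
)
circuit.compile_detector_sampler().sample_write(
    1000,
    filepath="surface_code_rotated_memory_x_13_0.01_1000_shots.b8",
    format="b8",
    append_observables=True,
)
```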
