diff --git a/qualtran/_infra/bloq.py b/qualtran/_infra/bloq.py index 8572718445..61639454d2 100644 --- a/qualtran/_infra/bloq.py +++ b/qualtran/_infra/bloq.py @@ -60,7 +60,7 @@ GeneralizerT, SympySymbolAllocator, ) - from qualtran.simulation.classical_sim import ClassicalValRetT, ClassicalValT + from qualtran.simulation.classical_sim import ClassicalValRetT, ClassicalValT, MeasurementPhase from qualtran.simulation.tensor import DiscardInd @@ -220,7 +220,9 @@ def on_classical_vals( except NotImplementedError as e: raise NotImplementedError(f"{self} does not support classical simulation: {e}") from e - def basis_state_phase(self, **vals: 'ClassicalValT') -> Union[complex, None]: + def basis_state_phase( + self, **vals: 'ClassicalValT' + ) -> Union[complex, 'MeasurementPhase', None]: """How this bloq phases classical basis states. Override this method if your bloq represents classical logic with basis-state @@ -231,7 +233,8 @@ def basis_state_phase(self, **vals: 'ClassicalValT') -> Union[complex, None]: (X, CNOT, Toffoli, ...) and diagonal operations (T, CZ, CCZ, ...). Bloq authors should override this method. If you are using an instantiated bloq object, - call TODO and not this method directly. + call `qualtran.simulation.classical_sim.do_phased_classical_simulation` or use + `qualtran.simulation.classical_sim.PhasedClassicalSimState`. If this method is implemented, `on_classical_vals` must also be implemented. If `on_classical_vals` is implemented but this method is not implemented, it is assumed diff --git a/qualtran/_infra/composite_bloq.py b/qualtran/_infra/composite_bloq.py index 9832b9ffaa..1d3f91e0b4 100644 --- a/qualtran/_infra/composite_bloq.py +++ b/qualtran/_infra/composite_bloq.py @@ -512,6 +512,42 @@ def _binst_to_cxns( return pred_cxns, succ_cxns +def _get_soquet( + binst: 'BloqInstance', + reg_name: str, + right: bool = False, + idx: Tuple[int, ...] = (), + *, + binst_graph: nx.DiGraph, +) -> 'Soquet': + """Retrieve a soquet given identifying information. + + We can uniquely address a Soquet by the arguments to this function. + + Args: + binst: The bloq instance associated with the desired soquet. + reg_name: The name of the register associated with the desired soquet. + right: If False, get the input, left soquet. Otherwise: the right, output soquet + idx: The index of the soquet within a multidimensional register, or the empty + tuple for basic registers. + """ + preds, succs = _binst_to_cxns(binst, binst_graph=binst_graph) + if right: + for suc in succs: + me = suc.left + if me.reg.name == reg_name and me.idx == idx: + return me + else: + for pred in preds: + me = pred.right + if me.reg.name == reg_name and me.idx == idx: + return me + + raise ValueError( + f"Could not find the requested soquet with {binst=}, {reg_name=}, {right=}, {idx=}" + ) + + def _cxns_to_soq_dict( regs: Iterable[Register], cxns: Iterable[Connection], diff --git a/qualtran/_infra/controlled.py b/qualtran/_infra/controlled.py index 0e99500728..cac91d441d 100644 --- a/qualtran/_infra/controlled.py +++ b/qualtran/_infra/controlled.py @@ -47,7 +47,7 @@ from qualtran.cirq_interop import CirqQuregT from qualtran.drawing import WireSymbol from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator - from qualtran.simulation.classical_sim import ClassicalValRetT, ClassicalValT + from qualtran.simulation.classical_sim import ClassicalValRetT, ClassicalValT, MeasurementPhase ControlBit: TypeAlias = int """A control bit, either 0 or 1.""" @@ -380,10 +380,7 @@ def ctrl_spec(self) -> 'CtrlSpec': @cached_property def _thru_registers_only(self) -> bool: - for reg in self.subbloq.signature: - if reg.side != Side.THRU: - return False - return True + return self.signature.thru_registers_only @staticmethod def _make_ctrl_system(cb: '_ControlledBase') -> Tuple['_ControlledBase', 'AddControlledT']: @@ -453,7 +450,9 @@ def on_classical_vals(self, **vals: 'ClassicalValT') -> Mapping[str, 'ClassicalV return vals - def basis_state_phase(self, **vals: 'ClassicalValT') -> Union[complex, None]: + def basis_state_phase( + self, **vals: 'ClassicalValT' + ) -> Union[complex, 'MeasurementPhase', None]: """Phasing action of controlled bloqs. This involves conditionally doing the phasing action of `subbloq`. All implementers @@ -533,7 +532,15 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) - from qualtran.drawing import Text if reg is None: - return Text(f'C[{self.subbloq}]') + sub_title = self.subbloq.wire_symbol(None, idx) + if not isinstance(sub_title, Text): + raise ValueError( + f"{self.subbloq} should return a `Text` object for reg=None wire symbol." + ) + if sub_title.text == '': + return Text('') + + return Text(f'C[{sub_title.text}]') if reg.name not in self.ctrl_reg_names: # Delegate to subbloq return self.subbloq.wire_symbol(reg, idx) @@ -688,6 +695,7 @@ def make_ctrl_system_with_correct_metabloq( `ControlledViaAnd`, which computes the activation function once and re-uses it for each subbloq in the decomposition of `bloq`. """ + from qualtran.bloqs.mcmt.classically_controlled import ClassicallyControlled from qualtran.bloqs.mcmt.controlled_via_and import ControlledViaAnd if ctrl_spec == CtrlSpec(): @@ -710,6 +718,6 @@ def make_ctrl_system_with_correct_metabloq( if qdtypes: return ControlledViaAnd.make_ctrl_system(bloq, ctrl_spec=ctrl_spec) if cdtypes: - raise NotImplementedError("Stay tuned...") + return ClassicallyControlled.make_ctrl_system(bloq, ctrl_spec=ctrl_spec) raise ValueError(f"Invalid control spec: {ctrl_spec}") diff --git a/qualtran/_infra/registers.py b/qualtran/_infra/registers.py index bfeb8578d0..daf8ae3c9f 100644 --- a/qualtran/_infra/registers.py +++ b/qualtran/_infra/registers.py @@ -16,6 +16,7 @@ import enum import itertools from collections import defaultdict +from functools import cached_property from typing import cast, Dict, Iterable, Iterator, List, overload, Tuple, Union import attrs @@ -179,6 +180,13 @@ def build_from_dtypes(cls, **registers: QCDType) -> 'Signature': """ return cls(Register(name=k, dtype=v) for k, v in registers.items() if v.num_qubits) + @cached_property + def thru_registers_only(self) -> bool: + for reg in self: + if reg.side != Side.THRU: + return False + return True + def lefts(self) -> Iterable[Register]: """Iterable over all registers that appear on the LEFT as input.""" yield from self._lefts.values() diff --git a/qualtran/bloqs/basic_gates/__init__.py b/qualtran/bloqs/basic_gates/__init__.py index d89e5080a9..a1ce524e9e 100644 --- a/qualtran/bloqs/basic_gates/__init__.py +++ b/qualtran/bloqs/basic_gates/__init__.py @@ -34,7 +34,7 @@ from .swap import CSwap, Swap, TwoBitCSwap, TwoBitSwap from .t_gate import TGate from .toffoli import Toffoli -from .x_basis import MinusEffect, MinusState, PlusEffect, PlusState, XGate +from .x_basis import MeasX, MinusEffect, MinusState, PlusEffect, PlusState, XGate from .y_gate import CYGate, YGate from .z_basis import ( CZ, diff --git a/qualtran/bloqs/basic_gates/x_basis.py b/qualtran/bloqs/basic_gates/x_basis.py index 5b50f5da8d..f6e8dafc31 100644 --- a/qualtran/bloqs/basic_gates/x_basis.py +++ b/qualtran/bloqs/basic_gates/x_basis.py @@ -24,6 +24,7 @@ bloq_example, BloqBuilder, BloqDocSpec, + CBit, ConnectionT, CtrlSpec, QBit, @@ -33,6 +34,12 @@ SoquetT, ) from qualtran.drawing import directional_text_box, Text, WireSymbol +from qualtran.simulation.classical_sim import ( + ClassicalValDistribution, + ClassicalValRetT, + ClassicalValT, + MeasurementPhase, +) if TYPE_CHECKING: import cirq @@ -41,7 +48,6 @@ from pennylane.wires import Wires from qualtran.cirq_interop import CirqQuregT - from qualtran.simulation.classical_sim import ClassicalValT _PLUS = np.ones(2, dtype=np.complex128) / np.sqrt(2) _MINUS = np.array([1, -1], dtype=np.complex128) / np.sqrt(2) @@ -85,7 +91,10 @@ def my_tensors( ] def as_cirq_op( - self, qubit_manager: 'cirq.QubitManager', **cirq_quregs: 'CirqQuregT' # type: ignore[type-var] + self, + qubit_manager: 'cirq.QubitManager', + **cirq_quregs: 'CirqQuregT', + # type: ignore[type-var] ) -> Tuple[Union['cirq.Operation', None], Dict[str, 'CirqQuregT']]: # type: ignore[type-var] if not self.state: raise ValueError(f"There is no Cirq equivalent for {self}") @@ -270,3 +279,47 @@ def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSym return Text('X') return ModPlus() + + def __str__(self): + return 'X' + + +@frozen +class MeasX(Bloq): + @cached_property + def signature(self) -> 'Signature': + return Signature( + [Register('q', QBit(), side=Side.LEFT), Register('c', CBit(), side=Side.RIGHT)] + ) + + def on_classical_vals(self, q: int) -> Dict[str, 'ClassicalValRetT']: + if q not in [0, 1]: + raise ValueError(f"Invalid classical value encountered in {self}: {q}") + return {'c': ClassicalValDistribution(2)} + + def basis_state_phase(self, q: int) -> Union[complex, MeasurementPhase]: + if q == 0: + return 1 + if q == 1: + return MeasurementPhase(reg_name='c') + raise ValueError(f"Invalid classical value encountered in {self}: {q}") + + def my_tensors( + self, incoming: Dict[str, 'ConnectionT'], outgoing: Dict[str, 'ConnectionT'] + ) -> List['qtn.Tensor']: + import quimb.tensor as qtn + + from qualtran.simulation.tensor import DiscardInd + + data = np.array( + [ + [[0.5 + 0.0j, 0.5 + 0.0j], [0.5 + 0.0j, -0.5 + 0.0j]], + [[0.5 + 0.0j, -0.5 + 0.0j], [0.5 + 0.0j, 0.5 + 0.0j]], + ] + ) + + q_trace = qtn.rand_uuid('q_trace') + t = qtn.Tensor( + data=data, inds=[(incoming['q'], 0), (q_trace, 0), (outgoing['c'], 0)], tags=[str(self)] + ) + return [t, DiscardInd((q_trace, 0))] diff --git a/qualtran/bloqs/basic_gates/x_basis_test.py b/qualtran/bloqs/basic_gates/x_basis_test.py index edcde41797..6404247fc0 100644 --- a/qualtran/bloqs/basic_gates/x_basis_test.py +++ b/qualtran/bloqs/basic_gates/x_basis_test.py @@ -15,7 +15,7 @@ import numpy as np from qualtran import BloqBuilder -from qualtran.bloqs.basic_gates import MinusState, PlusEffect, PlusState, XGate +from qualtran.bloqs.basic_gates import MeasX, MinusState, PlusEffect, PlusState, XGate from qualtran.resource_counting import GateCounts, get_cost_value, QECGatesCost from qualtran.simulation.classical_sim import ( format_classical_truth_table, @@ -119,3 +119,8 @@ def _keep_and(b): bloq = XGate().controlled(CtrlSpec(qdtypes=QUInt(n), cvs=1)) _, sigma = bloq.call_graph(keep=_keep_and) assert sigma == {And(): n - 1, CNOT(): 1, And().adjoint(): n - 1, XGate(): 4 * (n - 1)} + + +def test_meas_x_classical_sim(): + m = MeasX() + m.call_classically(q=0) diff --git a/qualtran/bloqs/mcmt/and_bloq.py b/qualtran/bloqs/mcmt/and_bloq.py index 486bb3e49b..50f4bc5a5a 100644 --- a/qualtran/bloqs/mcmt/and_bloq.py +++ b/qualtran/bloqs/mcmt/and_bloq.py @@ -103,7 +103,10 @@ def on_classical_vals( return {'ctrl': ctrl, 'target': out} # Uncompute - assert target == out + if target != out: + raise ValueError( + f"Inconsistent `target` found for uncomputing `And`: {ctrl=}, {target=}. Expected target={out}" + ) return {'ctrl': ctrl} def my_tensors( diff --git a/qualtran/bloqs/mcmt/classically_controlled.py b/qualtran/bloqs/mcmt/classically_controlled.py new file mode 100644 index 0000000000..0d8c72e513 --- /dev/null +++ b/qualtran/bloqs/mcmt/classically_controlled.py @@ -0,0 +1,40 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple + +import attrs + +from qualtran import AddControlledT, Bloq, CDType, CtrlSpec, QCDType +from qualtran._infra.controlled import _ControlledBase + + +@attrs.frozen +class ClassicallyControlled(_ControlledBase): + + subbloq: 'Bloq' + ctrl_spec: 'CtrlSpec' + + def __attrs_post_init__(self): + for qcdtype in self.ctrl_spec.qdtypes: + if not isinstance(qcdtype, QCDType): + raise ValueError(f"Invalid type found in `ctrl_spec`: {qcdtype}") + if not isinstance(qcdtype, CDType): + raise ValueError(f"Invalid type found in `ctrl_spec`: {qcdtype}") + + @classmethod + def make_ctrl_system( + cls, bloq: 'Bloq', ctrl_spec: 'CtrlSpec' + ) -> Tuple['_ControlledBase', 'AddControlledT']: + cb = cls(subbloq=bloq, ctrl_spec=ctrl_spec) + return cls._make_ctrl_system(cb) diff --git a/qualtran/resource_counting/_bloq_counts.py b/qualtran/resource_counting/_bloq_counts.py index 3ed9a4a10a..d293ecd5ab 100644 --- a/qualtran/resource_counting/_bloq_counts.py +++ b/qualtran/resource_counting/_bloq_counts.py @@ -297,7 +297,15 @@ class QECGatesCost(CostKey[GateCounts]): legacy_shims: bool = False def compute(self, bloq: 'Bloq', get_callee_cost: Callable[['Bloq'], GateCounts]) -> GateCounts: - from qualtran.bloqs.basic_gates import GlobalPhase, Identity, Toffoli, TwoBitCSwap + from qualtran.bloqs.basic_gates import ( + Discard, + GlobalPhase, + Identity, + MeasX, + MeasZ, + Toffoli, + TwoBitCSwap, + ) from qualtran.bloqs.basic_gates._shims import Measure from qualtran.bloqs.bookkeeping._bookkeeping_bloq import _BookkeepingBloq from qualtran.bloqs.mcmt import And, MultiTargetCNOT @@ -326,7 +334,7 @@ def compute(self, bloq: 'Bloq', get_callee_cost: Callable[['Bloq'], GateCounts]) return GateCounts(toffoli=1) # Measurement - if isinstance(bloq, Measure): + if isinstance(bloq, (Measure, MeasZ, MeasX)): return GateCounts(measurement=1) # 'And' bloqs @@ -370,9 +378,10 @@ def compute(self, bloq: 'Bloq', get_callee_cost: Callable[['Bloq'], GateCounts]) return GateCounts() # Bookkeeping, empty bloqs - if isinstance(bloq, _BookkeepingBloq) or isinstance(bloq, (GlobalPhase, Identity)): + if isinstance(bloq, _BookkeepingBloq) or isinstance(bloq, (GlobalPhase, Identity, Discard)): return GateCounts() + # Rotations if bloq_is_rotation(bloq): return GateCounts(rotation=1) diff --git a/qualtran/simulation/MBUC.ipynb b/qualtran/simulation/MBUC.ipynb new file mode 100644 index 0000000000..73fe50fc6f --- /dev/null +++ b/qualtran/simulation/MBUC.ipynb @@ -0,0 +1,375 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "44f40baf-3f87-40c5-9bf1-409ac3f86e68", + "metadata": {}, + "source": [ + "# Verifying Measurement-Based Uncomputation\n", + "\n", + "Quantum information cannot be destroyed, but during a computation we may produce intermediate values that we wish to discard. We can \"uncompute\" these values by running the computation in reverse. The ordinary uncomputation strategy requires paying the cost of the computation twice, but [*Halving the cost of quantum addition.* Gidney 2017](https://arxiv.org/abs/1709.06648) shows how measurement in the X basis can effectively discard a bit without expensive uncomputation. The consequence is that the remaining states of the system will pick up phases depending on the random measurement result. [*Verifying Measurement Based Uncomputation.* Gidney 2019](https://algassert.com/post/1903) provides more detail about these phases. It also describes a proceedure for using a phased-classical simulator to \"fuzz test\" measurement-based uncomputation circuits.\n", + "\n", + "Here, we show how Qualtran can be used to verify measurement based uncomputation circuits following Gidney's proposal." + ] + }, + { + "cell_type": "markdown", + "id": "3c3fdd7b-a94b-4f57-bb51-eec6217018df", + "metadata": {}, + "source": [ + "## Uncomputing $\\mathrm{And}$\n", + "\n", + "As a warm-up, we can use the reference classical action of `And(uncompute=True)` to verify the truth table of the operation. First, we check the bloq over valid inputs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3dddb0a5-a2b6-44cb-beb9-79b9600aae5e", + "metadata": {}, + "outputs": [], + "source": [ + "import itertools\n", + "from qualtran.bloqs.mcmt import And\n", + "\n", + "and_dag = And(uncompute=True)\n", + "for q1, q2 in itertools.product(range(2), repeat=2):\n", + " trg = int(q1==1 and q2 == 1)\n", + " print(f'{q1=}, {q2=}, {trg=}', end=' ')\n", + " (q1o, q2o), = and_dag.call_classically(ctrl=[q1,q2], target=trg)\n", + " assert q1o == q1\n", + " assert q2o == q2\n", + " print('✓')" + ] + }, + { + "cell_type": "markdown", + "id": "33c34f76-b53c-4ef7-9d72-1e7f6ebf7725", + "metadata": {}, + "source": [ + "In a quantum computer, there is no error handling; but the classical simulation will helpfully inform you if you supply invalid inputs to the bloq. Here, there is an error because the `target` register does not contain the result of a (forwards) computation of $\\mathrm{And}$." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dae8d49d-cc11-4d1c-84f7-70765d7c621b", + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " and_dag.call_classically(ctrl=[1,1], target=0)\n", + "except ValueError as e:\n", + " print(e)" + ] + }, + { + "cell_type": "markdown", + "id": "c82cfd5d-8482-496c-8084-0abeb315a9ca", + "metadata": {}, + "source": [ + "## Naive attempt at $\\mathrm{And}^\\dagger$\n", + "\n", + "What happens if we just measure the target bit in the X basis and throw it away? We'll build this simple circuit below so we can use the phased-classical simulator to find out." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a76a727-a742-422b-a8ac-e588f00fe765", + "metadata": {}, + "outputs": [], + "source": [ + "from qualtran import BloqBuilder, Register, QBit, Side, Controlled, CtrlSpec, CBit\n", + "from qualtran.bloqs.basic_gates import MeasX, Discard, CZ\n", + "\n", + "bb = BloqBuilder()\n", + "q1 = bb.add_register('q1', 1)\n", + "q2 = bb.add_register('q2', 1)\n", + "trg = bb.add_register(Register('trg', QBit(), side=Side.LEFT))\n", + "\n", + "ctrg = bb.add(MeasX(), q=trg)\n", + "bb.add(Discard(), c=ctrg)\n", + "\n", + "throw_out_target = bb.finalize(q1=q1, q2=q2)\n", + "from qualtran.drawing import show_bloq\n", + "show_bloq(throw_out_target, 'musical_score')" + ] + }, + { + "cell_type": "markdown", + "id": "ca44fb5b-e961-434e-af6d-5deaa4fa984d", + "metadata": {}, + "source": [ + "## Fuzz testing measurement circuits\n", + "\n", + "Given a computational basis state input, the X-basis measurement operation returns a random outcome. We explicitly supply a random number generator to the phased classical simulation function to support these circuits.\n", + "\n", + "Since our simulation is now stochastic, we run it 10 times and see if we get the right answer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d123b5b4-dc1c-4957-b214-e4b9d41bb700", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from qualtran.simulation.classical_sim import do_phased_classical_simulation\n", + "\n", + "rng = np.random.default_rng(seed=123)\n", + "in_vals = {'q1': 1, 'q2': 1, 'trg': 1}\n", + "for _ in range(10):\n", + " out_vals, phase = do_phased_classical_simulation(throw_out_target, in_vals, rng=rng)\n", + " assert out_vals['q1'] == 1\n", + " assert out_vals['q2'] == 1\n", + " assert 'trg' not in out_vals\n", + " if phase == 1:\n", + " print(\"✓\", end=' ')\n", + " else:\n", + " print(f\"Bad phase: {phase}\")" + ] + }, + { + "cell_type": "markdown", + "id": "368810bd-5ffb-46c6-8a83-6db6d09dcd78", + "metadata": {}, + "source": [ + "A phase on our computational basis state will result in *relative phases amongst* the computational basis states when this operation is called on a register in superposition, so these spurious phases must be fixed." + ] + }, + { + "cell_type": "markdown", + "id": "1dba8338-21d9-49ff-b19a-845663779039", + "metadata": {}, + "source": [ + "## MBUC circuit for $\\mathrm{And}^\\dagger$\n", + "\n", + "So simply measuring the bit in an orthogonal basis and throwing it away hasn't worked. The fix here is straightforward: a phase is encountered when the target bit is `1` and the random measurement outcome is also `1`, so we can flip it back. We flip the phase conditioned on 1) the two control qubits being `1` and 2) the classical measurement result being `1`. The first condition can be achieved with a `CZ`. We use a classically-controlled `CZ` to implement conditions (1) *and* (2) with only a Clifford operation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6717ad61-132d-4df7-9d7e-63a38ddcf233", + "metadata": {}, + "outputs": [], + "source": [ + "bb = BloqBuilder()\n", + "q1 = bb.add_register('q1', 1)\n", + "q2 = bb.add_register('q2', 1)\n", + "trg = bb.add_register(Register('trg', QBit(), side=Side.LEFT))\n", + "\n", + "ctrg = bb.add(MeasX(), q=trg)\n", + "classically_controlled_cz = CZ().controlled(CtrlSpec(qdtypes=[CBit()]))\n", + "ctrg, q1, q2 = bb.add(\n", + " classically_controlled_cz,\n", + " **{'ctrl': ctrg,\n", + " 'q1': q1,\n", + " 'q2': q2\n", + " }\n", + ")\n", + "bb.add(Discard(), c=ctrg)\n", + "\n", + "mbuc_target = bb.finalize(q1=q1, q2=q2)\n", + "show_bloq(mbuc_target, 'musical_score')" + ] + }, + { + "cell_type": "markdown", + "id": "2d2455c1-715f-4040-a7d1-a9197e9b1992", + "metadata": {}, + "source": [ + "## Fuzz testing MBUC\n", + "\n", + "We can continue to use random measurement results in simulation to \"fuzz test\" our construction. Here, all ten runs pass our check." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65b0fbf5-f341-452f-af32-74e309207e48", + "metadata": {}, + "outputs": [], + "source": [ + "rng = np.random.default_rng(seed=123)\n", + "in_vals = {'q1': 1, 'q2': 1, 'trg': 1}\n", + "for _ in range(10):\n", + " out_vals, phase = do_phased_classical_simulation(mbuc_target, in_vals, rng=rng)\n", + " assert out_vals['q1'] == 1\n", + " assert out_vals['q2'] == 1\n", + " assert 'trg' not in out_vals\n", + " if phase == 1:\n", + " print(\"✓\", end=' ')\n", + " else:\n", + " print(f\"Bad phase: {phase}\")" + ] + }, + { + "cell_type": "markdown", + "id": "4c6a8409-0dff-47ec-8a44-efa7edcc157c", + "metadata": {}, + "source": [ + "## Exhaustive testing of MBUC\n", + "\n", + "With some additional work, we can inject particular patterns of measurement results to check all possible cases. For circuits with a small number of `MeasX` bloqs, this can be more valuable than fuzz testing. The exhaustive number of cases grows exponentially in the number of measured bits." + ] + }, + { + "cell_type": "markdown", + "id": "aa8e8503-242f-42c3-9f89-ae24a4dd7934", + "metadata": {}, + "source": [ + "#### Preparation: find the bloq index of our measurement operation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0e233b4-2ec0-4c37-9b3f-101be212e673", + "metadata": {}, + "outputs": [], + "source": [ + "# Prep work: find the bloq instance indices of measurement operations.\n", + "# Here, there's only one; but this code snippet will work for MBUC circuits\n", + "# with additional MeasX bloqs\n", + "cbloq = mbuc_target\n", + "meas_binst_is = [binst.i for binst in cbloq.bloq_instances if binst.bloq_is(MeasX)]\n", + "assert len(meas_binst_is) == 1, 'this circuit only has one'\n", + "meas_binst_i = meas_binst_is[0]\n", + "meas_binst_i" + ] + }, + { + "cell_type": "markdown", + "id": "60ec0fd0-1b48-45d5-80da-de4cac34e9aa", + "metadata": {}, + "source": [ + "### Loop over inputs *and* measurement results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "98c2a2fb-125e-4388-a5a2-f905083cbf31", + "metadata": {}, + "outputs": [], + "source": [ + "from qualtran.simulation.classical_sim import PhasedClassicalSimState\n", + "import itertools\n", + "\n", + "for q1, q2 in itertools.product(range(2), repeat=2):\n", + " trg = int(q1==1 and q2 == 1)\n", + " print(f'{q1=}, {q2=}, {trg=}')\n", + "\n", + " for meas_result in [0, 1]:\n", + " print(f' meas {meas_result}', end=' ')\n", + " fixed_rnd_vals = {meas_binst_i: meas_result}\n", + " sim = PhasedClassicalSimState.from_cbloq(\n", + " cbloq, \n", + " vals={'q1': q1, 'q2': q2, 'trg': trg},\n", + " fixed_rnd_vals={meas_binst_i: meas_result}\n", + " )\n", + " out_vals = sim.simulate()\n", + " \n", + " assert out_vals['q1'] == q1\n", + " assert out_vals['q2'] == q2\n", + " assert 'trg' not in out_vals\n", + " assert phase == 1.0\n", + " print(' ✓')" + ] + }, + { + "cell_type": "markdown", + "id": "f5b50da9-cbf8-4663-9cb2-164675c079ff", + "metadata": {}, + "source": [ + "### Inspecting the phase during simulation\n", + "\n", + "For visibility into the progress of the simulation, we extend the `step` method of the simulator to print out the current phase of the system. We've also modified the exhaustive loop to use `itertools.product` so this code snippet can handle circuits with multiple `MeasX` gates (with exponential scaling). " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20a732fe-74d8-43ea-80bf-f7c3d7b09e3f", + "metadata": {}, + "outputs": [], + "source": [ + "class DebugPhasedClassicalSim(PhasedClassicalSimState):\n", + " \"\"\"Phased-classical simulator that prints debug information.\"\"\"\n", + " \n", + " def step(self):\n", + " \"\"\"At each step, print a brief representation of the current phase.\"\"\"\n", + " super().step()\n", + " if sim.phase == 1.0:\n", + " print('+', end='')\n", + " elif sim.phase == -1.0:\n", + " print('-', end='')\n", + " else:\n", + " print('?', end='')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47c2db6d-b7ba-4710-9406-e1737d34cef1", + "metadata": {}, + "outputs": [], + "source": [ + "meas_binst_is = [binst.i for binst in cbloq.bloq_instances if binst.bloq_is(MeasX)]\n", + "\n", + "for q1, q2 in itertools.product(range(2), repeat=2):\n", + " trg = int(q1==1 and q2 == 1)\n", + " print(f'{q1=}, {q2=}, {trg=}')\n", + "\n", + " for meas_result in itertools.product(range(2), repeat=len(meas_binst_is)):\n", + " print(f' meas {meas_result}', end=' ')\n", + " fixed_rnd_vals = {binst_i: meas_result[j] for j, binst_i in enumerate(meas_binst_is)}\n", + "\n", + " sim = DebugPhasedClassicalSim.from_cbloq(\n", + " cbloq,\n", + " {'q1': q1, 'q2': q2, 'trg': trg},\n", + " fixed_rnd_vals=fixed_rnd_vals\n", + " )\n", + " out_vals = sim.simulate()\n", + " \n", + " assert out_vals['q1'] == q1\n", + " assert out_vals['q2'] == q2\n", + " assert 'trg' not in out_vals\n", + " assert phase == 1.0\n", + " print(' ✓')" + ] + }, + { + "cell_type": "markdown", + "id": "89b36959-ce99-4b65-a89a-8175b4849127", + "metadata": {}, + "source": [ + "Note that the phase is unaffected for all cases except when the `target` bit is `1` *and* the measurement result is `1`. Note that it is immediately fixed by the classically-controlled CZ." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/qualtran/simulation/classical_sim.py b/qualtran/simulation/classical_sim.py index 927c93e00a..a771984657 100644 --- a/qualtran/simulation/classical_sim.py +++ b/qualtran/simulation/classical_sim.py @@ -13,6 +13,7 @@ # limitations under the License. """Functionality for the `Bloq.call_classically(...)` protocol.""" +import abc import itertools from typing import ( Any, @@ -28,6 +29,7 @@ Union, ) +import attrs import networkx as nx import numpy as np import sympy @@ -43,13 +45,13 @@ Signature, Soquet, ) -from qualtran._infra.composite_bloq import _binst_to_cxns +from qualtran._infra.composite_bloq import _binst_to_cxns, _get_soquet if TYPE_CHECKING: from qualtran import CompositeBloq, QCDType ClassicalValT = Union[int, np.integer, NDArray[np.integer]] -ClassicalValRetT = Union[int, np.integer, NDArray[np.integer]] +ClassicalValRetT = Union[int, np.integer, NDArray[np.integer], 'ClassicalValDistribution'] def _numpy_dtype_from_qlt_dtype(dtype: 'QCDType') -> Type: @@ -106,6 +108,49 @@ def _get_in_vals( return arg +@attrs.frozen(hash=False) +class ClassicalValDistribution: + """Return this if ... + + Args: + a: An array of choices, or `np.arange` if an integer is given. This is the `a` parameter + to `np.random.Generator.choice()`. + p: An array of probabilities. If not supplied, the uniform distribution is assumed. This + is the `p` parameter to `np.random.Generator.choice()`. + """ + + a: Union[int, np.typing.ArrayLike] + p: Optional[np.typing.ArrayLike] = None + + +class _RandomValHandler(metaclass=abc.ABCMeta): + + @abc.abstractmethod + def get(self, binst: 'BloqInstance', a, p) -> Any: ... + + +class _RandomRandomValHandler(_RandomValHandler): + def __init__(self, rng): + self._gen = rng + + def get(self, binst, a, p): + return self._gen.choice(a, p=p) + + +class _FixedRandomValHandler(_RandomValHandler): + def __init__(self, binst_i_to_val: Dict[int, Any]): + self._binst_i_to_val = binst_i_to_val + + def get(self, binst, a, p): + return self._binst_i_to_val[binst.i] + + +class _BannedRandomValHandler(_RandomValHandler): + + def get(self, binst: 'BloqInstance', a, p) -> Any: + raise ValueError(f"{binst} has non-deterministic classical action. TODO: advice.") + + class ClassicalSimState: """A mutable class for classically simulating composite bloqs. @@ -138,10 +183,12 @@ def __init__( signature: 'Signature', binst_graph: nx.DiGraph, vals: Mapping[str, Union[sympy.Symbol, ClassicalValT]], + rnd_handler: '_RandomValHandler' = _BannedRandomValHandler(), ): self._signature = signature self._binst_graph = binst_graph self._binst_iter = nx.topological_sort(self._binst_graph) + self._rnd_handler = rnd_handler # Keep track of each soquet's bit array. Initialize with LeftDangle self.soq_assign: Dict[Soquet, ClassicalValT] = {} @@ -206,6 +253,8 @@ def _update_assign_from_vals( else: # `val` is one value. + if isinstance(val, ClassicalValDistribution): + val = self._rnd_handler.get(binst, val.a, val.p) reg.dtype.assert_valid_classical_val(val, debug_str) soq = Soquet(binst, reg) self.soq_assign[soq] = val @@ -298,6 +347,17 @@ def simulate(self) -> Dict[str, 'ClassicalValT']: return self.finalize() +@attrs.frozen +class MeasurementPhase: + """Sentinel value to return from `Bloq.basis_state_phase` if a phase should be applied based on a measurement outcome. + + This can be used in special circumstances to verify measurement-based uncomputation (MBUC). + """ + + reg_name: str + idx: Tuple[int, ...] = () + + class PhasedClassicalSimState(ClassicalSimState): """A mutable class for classically simulating composite bloqs with phase tracking. @@ -328,16 +388,22 @@ def __init__( signature: 'Signature', binst_graph: nx.DiGraph, vals: Mapping[str, Union[sympy.Symbol, ClassicalValT]], - *, + rnd_handler: '_RandomValHandler', phase: complex = 1.0, ): - super().__init__(signature=signature, binst_graph=binst_graph, vals=vals) + super().__init__( + signature=signature, binst_graph=binst_graph, vals=vals, rnd_handler=rnd_handler + ) _assert_valid_phase(phase) self.phase = phase @classmethod def from_cbloq( - cls, cbloq: 'CompositeBloq', vals: Mapping[str, Union[sympy.Symbol, ClassicalValT]] + cls, + cbloq: 'CompositeBloq', + vals: Mapping[str, Union[sympy.Symbol, ClassicalValT]], + rng=None, + fixed_rnd_vals=None, ) -> 'PhasedClassicalSimState': """Initiate a classical simulation from a CompositeBloq. @@ -349,7 +415,23 @@ def from_cbloq( Returns: A new classical sim state. """ - return cls(signature=cbloq.signature, binst_graph=cbloq._binst_graph, vals=vals) + if rng is not None and fixed_rnd_vals is not None: + raise ValueError("Supply either `seed` or `fixed_rnd_vals`, not both.") + + rnd_handler: _RandomValHandler + if rng is not None: + rnd_handler = _RandomRandomValHandler(rng=rng) + elif fixed_rnd_vals is not None: + rnd_handler = _FixedRandomValHandler(binst_i_to_val=fixed_rnd_vals) + else: + rnd_handler = _BannedRandomValHandler() + + return cls( + signature=cbloq.signature, + binst_graph=cbloq._binst_graph, + vals=vals, + rnd_handler=rnd_handler, + ) def _binst_basis_state_phase(self, binst, in_vals): """Call `basis_state_phase` on a given bloq instance. @@ -359,7 +441,25 @@ def _binst_basis_state_phase(self, binst, in_vals): """ bloq = binst.bloq bloq_phase = bloq.basis_state_phase(**in_vals) - if bloq_phase is not None: + if isinstance(bloq_phase, MeasurementPhase): + # In this special case, there is a coupling between the classical result and the + # phase result (because the classical result is stochastic). We look up the measurement + # result and apply a phase if it is `1`. + meas_result = self.soq_assign[ + _get_soquet( + binst=binst, + reg_name=bloq_phase.reg_name, + right=True, + idx=bloq_phase.idx, + binst_graph=self._binst_graph, + ) + ] + if meas_result == 1: + self.phase *= -1.0 + else: + # Measurement result of 0, phase of +1 + pass + elif bloq_phase is not None: _assert_valid_phase(bloq_phase) self.phase *= bloq_phase else: @@ -398,7 +498,9 @@ def _assert_valid_phase(p: complex, atol: float = 1e-8): raise ValueError(f"Phases must have unit modulus. Found {p}.") -def do_phased_classical_simulation(bloq: 'Bloq', vals: Mapping[str, 'ClassicalValT']): +def do_phased_classical_simulation( + bloq: 'Bloq', vals: Mapping[str, 'ClassicalValT'], rng: Optional['np.random.Generator'] = None +): """Do a phased classical simulation of the bloq. This provides a simple interface to `PhasedClassicalSimState`. Advanced users @@ -408,13 +510,16 @@ def do_phased_classical_simulation(bloq: 'Bloq', vals: Mapping[str, 'ClassicalVa bloq: The bloq to simulate vals: A mapping from input register name to initial classical values. The initial phase is assumed to be 1.0. + rng: A numpy random generator (e.g. from `np.random.default_rng()`). This function + will use this generator to supply random values from certain phased-classical operations + like `MeasX`. If not supplied, stochastic operations will result in an error. Returns: final_vals: A mapping of output register name to final classical values. phase: The final phase. """ cbloq = bloq.as_composite_bloq() - sim = PhasedClassicalSimState.from_cbloq(cbloq, vals=vals) + sim = PhasedClassicalSimState.from_cbloq(cbloq, vals=vals, rng=rng) final_vals = sim.simulate() phase = sim.phase return final_vals, phase