diff --git a/qualtran/bloqs/basic_gates/__init__.py b/qualtran/bloqs/basic_gates/__init__.py index d89e5080a9..a1ce524e9e 100644 --- a/qualtran/bloqs/basic_gates/__init__.py +++ b/qualtran/bloqs/basic_gates/__init__.py @@ -34,7 +34,7 @@ from .swap import CSwap, Swap, TwoBitCSwap, TwoBitSwap from .t_gate import TGate from .toffoli import Toffoli -from .x_basis import MinusEffect, MinusState, PlusEffect, PlusState, XGate +from .x_basis import MeasX, MinusEffect, MinusState, PlusEffect, PlusState, XGate from .y_gate import CYGate, YGate from .z_basis import ( CZ, diff --git a/qualtran/bloqs/basic_gates/x_basis.py b/qualtran/bloqs/basic_gates/x_basis.py index 5b50f5da8d..d8d9c2f184 100644 --- a/qualtran/bloqs/basic_gates/x_basis.py +++ b/qualtran/bloqs/basic_gates/x_basis.py @@ -24,6 +24,7 @@ bloq_example, BloqBuilder, BloqDocSpec, + CBit, ConnectionT, CtrlSpec, QBit, @@ -33,6 +34,12 @@ SoquetT, ) from qualtran.drawing import directional_text_box, Text, WireSymbol +from qualtran.simulation.classical_sim import ( + ClassicalValDistribution, + ClassicalValRetT, + ClassicalValT, + MeasurementPhase, +) if TYPE_CHECKING: import cirq @@ -41,7 +48,6 @@ from pennylane.wires import Wires from qualtran.cirq_interop import CirqQuregT - from qualtran.simulation.classical_sim import ClassicalValT _PLUS = np.ones(2, dtype=np.complex128) / np.sqrt(2) _MINUS = np.array([1, -1], dtype=np.complex128) / np.sqrt(2) @@ -270,3 +276,54 @@ def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSym return Text('X') return ModPlus() + + def __str__(self): + return 'X' + + +@frozen +class MeasX(Bloq): + """Measure a qubit in the X basis. + + Registers: + q [LEFT]: The qubit to measure. + c [RIGHT]: The classical measurement result. + """ + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [Register('q', QBit(), side=Side.LEFT), Register('c', CBit(), side=Side.RIGHT)] + ) + + def on_classical_vals(self, q: int) -> Dict[str, 'ClassicalValRetT']: + if q not in [0, 1]: + raise ValueError(f"Invalid classical value encountered in {self}: {q}") + return {'c': ClassicalValDistribution(2)} + + def basis_state_phase(self, q: int) -> Union[complex, MeasurementPhase]: + if q == 0: + return 1 + if q == 1: + return MeasurementPhase(reg_name='c') + raise ValueError(f"Invalid classical value encountered in {self}: {q}") + + def my_tensors( + self, incoming: Dict[str, 'ConnectionT'], outgoing: Dict[str, 'ConnectionT'] + ) -> List['qtn.Tensor']: + import quimb.tensor as qtn + + from qualtran.simulation.tensor import DiscardInd + + data = np.array( + [ + [[0.5 + 0.0j, 0.5 + 0.0j], [0.5 + 0.0j, -0.5 + 0.0j]], + [[0.5 + 0.0j, -0.5 + 0.0j], [0.5 + 0.0j, 0.5 + 0.0j]], + ] + ) + + q_trace = qtn.rand_uuid('q_trace') + t = qtn.Tensor( + data=data, inds=[(incoming['q'], 0), (q_trace, 0), (outgoing['c'], 0)], tags=[str(self)] + ) + return [t, DiscardInd((q_trace, 0))] diff --git a/qualtran/bloqs/basic_gates/x_basis_test.py b/qualtran/bloqs/basic_gates/x_basis_test.py index edcde41797..8f16d3ec2b 100644 --- a/qualtran/bloqs/basic_gates/x_basis_test.py +++ b/qualtran/bloqs/basic_gates/x_basis_test.py @@ -13,13 +13,24 @@ # limitations under the License. import cirq import numpy as np +import pytest from qualtran import BloqBuilder -from qualtran.bloqs.basic_gates import MinusState, PlusEffect, PlusState, XGate +from qualtran.bloqs.basic_gates import ( + MeasX, + MinusState, + OneState, + PlusEffect, + PlusState, + XGate, + ZeroState, +) from qualtran.resource_counting import GateCounts, get_cost_value, QECGatesCost from qualtran.simulation.classical_sim import ( + do_phased_classical_simulation, format_classical_truth_table, get_classical_truth_table, + MeasurementPhase, ) @@ -119,3 +130,74 @@ def _keep_and(b): bloq = XGate().controlled(CtrlSpec(qdtypes=QUInt(n), cvs=1)) _, sigma = bloq.call_graph(keep=_keep_and) assert sigma == {And(): n - 1, CNOT(): 1, And().adjoint(): n - 1, XGate(): 4 * (n - 1)} + + +def test_meas_x_classical_sim() -> None: + bloq = MeasX() + + with pytest.raises(ValueError, match='MeasX imparts a phase'): + _ = bloq.call_classically(q=0) + + with pytest.raises(ValueError, match='Invalid classical value'): + _ = bloq.on_classical_vals(q=2) + + rng = np.random.default_rng(seed=12345) + results = [do_phased_classical_simulation(bloq, {'q': 0}, rng=rng) for _ in range(100)] + + # Assert measurements are random + assert all(c[0]['c'] in {0, 1} for c in results) + assert any(c[0]['c'] == 0 for c in results) + assert any(c[0]['c'] == 1 for c in results) + # Assert phase is 1 + assert all(c[1] == 1 for c in results) + + rng = np.random.default_rng(seed=12345) + results = [do_phased_classical_simulation(bloq, {'q': 1}, rng=rng) for _ in range(100)] + # Assert measurements are random + assert all(c[0]['c'] in {0, 1} for c in results) + assert any(c[0]['c'] == 0 for c in results) + assert any(c[0]['c'] == 1 for c in results) + # Assert phase is -1 only if measurement is 1 + assert all(c[1] == -1 for c in results if c[0]['c'] == 1) + assert all(c[1] == 1 for c in results if c[0]['c'] == 0) + + +def test_meas_x_basis_state_phase() -> None: + bloq = MeasX() + assert bloq.basis_state_phase(0) == 1 + assert bloq.basis_state_phase(1) == MeasurementPhase(reg_name='c') + + with pytest.raises(ValueError, match='Invalid classical value'): + _ = bloq.basis_state_phase(2) + + +def test_meas_z_supertensor(): + with pytest.raises(ValueError, match=r'.*superoperator.*'): + MeasX().tensor_contract() + + # Zero -> Fully mixed state + bb = BloqBuilder() + q = bb.add(ZeroState()) + c = bb.add(MeasX(), q=q) + cbloq = bb.finalize(c=c) + rho = cbloq.tensor_contract(superoperator=True) + should_be = np.asarray([[0.5, 0], [0, 0.5]]) + np.testing.assert_allclose(rho, should_be, atol=1e-8) + + # One -> Fully mixed state + bb = BloqBuilder() + q = bb.add(OneState()) + c = bb.add(MeasX(), q=q) + cbloq = bb.finalize(c=c) + rho = cbloq.tensor_contract(superoperator=True) + should_be = np.asarray([[0.5, 0], [0, 0.5]]) + np.testing.assert_allclose(rho, should_be, atol=1e-8) + + # Plus measurement -> deterministic zero + bb = BloqBuilder() + q = bb.add(PlusState()) + c = bb.add(MeasX(), q=q) + cbloq = bb.finalize(c=c) + rho = cbloq.tensor_contract(superoperator=True) + should_be = np.asarray([[1, 0], [0, 0]]) + np.testing.assert_allclose(rho, should_be, atol=1e-8)