Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion qualtran/bloqs/basic_gates/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from .swap import CSwap, Swap, TwoBitCSwap, TwoBitSwap
from .t_gate import TGate
from .toffoli import Toffoli
from .x_basis import MinusEffect, MinusState, PlusEffect, PlusState, XGate
from .x_basis import MeasX, MinusEffect, MinusState, PlusEffect, PlusState, XGate
from .y_gate import CYGate, YGate
from .z_basis import (
CZ,
Expand Down
59 changes: 58 additions & 1 deletion qualtran/bloqs/basic_gates/x_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
bloq_example,
BloqBuilder,
BloqDocSpec,
CBit,
ConnectionT,
CtrlSpec,
QBit,
Expand All @@ -33,6 +34,12 @@
SoquetT,
)
from qualtran.drawing import directional_text_box, Text, WireSymbol
from qualtran.simulation.classical_sim import (
ClassicalValDistribution,
ClassicalValRetT,
ClassicalValT,
MeasurementPhase,
)

if TYPE_CHECKING:
import cirq
Expand All @@ -41,7 +48,6 @@
from pennylane.wires import Wires

from qualtran.cirq_interop import CirqQuregT
from qualtran.simulation.classical_sim import ClassicalValT

_PLUS = np.ones(2, dtype=np.complex128) / np.sqrt(2)
_MINUS = np.array([1, -1], dtype=np.complex128) / np.sqrt(2)
Expand Down Expand Up @@ -270,3 +276,54 @@ def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSym
return Text('X')

return ModPlus()

def __str__(self):
return 'X'


@frozen
class MeasX(Bloq):
"""Measure a qubit in the X basis.

Registers:
q [LEFT]: The qubit to measure.
c [RIGHT]: The classical measurement result.
"""

@cached_property
def signature(self) -> 'Signature':
return Signature(
[Register('q', QBit(), side=Side.LEFT), Register('c', CBit(), side=Side.RIGHT)]
)

def on_classical_vals(self, q: int) -> Dict[str, 'ClassicalValRetT']:
if q not in [0, 1]:
raise ValueError(f"Invalid classical value encountered in {self}: {q}")
return {'c': ClassicalValDistribution(2)}

def basis_state_phase(self, q: int) -> Union[complex, MeasurementPhase]:
if q == 0:
return 1
if q == 1:
return MeasurementPhase(reg_name='c')
raise ValueError(f"Invalid classical value encountered in {self}: {q}")

def my_tensors(
self, incoming: Dict[str, 'ConnectionT'], outgoing: Dict[str, 'ConnectionT']
) -> List['qtn.Tensor']:
import quimb.tensor as qtn

from qualtran.simulation.tensor import DiscardInd

data = np.array(
[
[[0.5 + 0.0j, 0.5 + 0.0j], [0.5 + 0.0j, -0.5 + 0.0j]],
[[0.5 + 0.0j, -0.5 + 0.0j], [0.5 + 0.0j, 0.5 + 0.0j]],
]
)

q_trace = qtn.rand_uuid('q_trace')
t = qtn.Tensor(
data=data, inds=[(incoming['q'], 0), (q_trace, 0), (outgoing['c'], 0)], tags=[str(self)]
)
return [t, DiscardInd((q_trace, 0))]
84 changes: 83 additions & 1 deletion qualtran/bloqs/basic_gates/x_basis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,24 @@
# limitations under the License.
import cirq
import numpy as np
import pytest

from qualtran import BloqBuilder
from qualtran.bloqs.basic_gates import MinusState, PlusEffect, PlusState, XGate
from qualtran.bloqs.basic_gates import (
MeasX,
MinusState,
OneState,
PlusEffect,
PlusState,
XGate,
ZeroState,
)
from qualtran.resource_counting import GateCounts, get_cost_value, QECGatesCost
from qualtran.simulation.classical_sim import (
do_phased_classical_simulation,
format_classical_truth_table,
get_classical_truth_table,
MeasurementPhase,
)


Expand Down Expand Up @@ -119,3 +130,74 @@ def _keep_and(b):
bloq = XGate().controlled(CtrlSpec(qdtypes=QUInt(n), cvs=1))
_, sigma = bloq.call_graph(keep=_keep_and)
assert sigma == {And(): n - 1, CNOT(): 1, And().adjoint(): n - 1, XGate(): 4 * (n - 1)}


def test_meas_x_classical_sim() -> None:
bloq = MeasX()

with pytest.raises(ValueError, match='MeasX imparts a phase'):
_ = bloq.call_classically(q=0)
Comment on lines +138 to +139
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍


with pytest.raises(ValueError, match='Invalid classical value'):
_ = bloq.on_classical_vals(q=2)

rng = np.random.default_rng(seed=12345)
results = [do_phased_classical_simulation(bloq, {'q': 0}, rng=rng) for _ in range(100)]

# Assert measurements are random
assert all(c[0]['c'] in {0, 1} for c in results)
assert any(c[0]['c'] == 0 for c in results)
assert any(c[0]['c'] == 1 for c in results)
# Assert phase is 1
assert all(c[1] == 1 for c in results)

rng = np.random.default_rng(seed=12345)
results = [do_phased_classical_simulation(bloq, {'q': 1}, rng=rng) for _ in range(100)]
# Assert measurements are random
assert all(c[0]['c'] in {0, 1} for c in results)
assert any(c[0]['c'] == 0 for c in results)
assert any(c[0]['c'] == 1 for c in results)
# Assert phase is -1 only if measurement is 1
assert all(c[1] == -1 for c in results if c[0]['c'] == 1)
assert all(c[1] == 1 for c in results if c[0]['c'] == 0)


def test_meas_x_basis_state_phase() -> None:
bloq = MeasX()
assert bloq.basis_state_phase(0) == 1
assert bloq.basis_state_phase(1) == MeasurementPhase(reg_name='c')

with pytest.raises(ValueError, match='Invalid classical value'):
_ = bloq.basis_state_phase(2)


def test_meas_z_supertensor():
with pytest.raises(ValueError, match=r'.*superoperator.*'):
MeasX().tensor_contract()

# Zero -> Fully mixed state
bb = BloqBuilder()
q = bb.add(ZeroState())
c = bb.add(MeasX(), q=q)
cbloq = bb.finalize(c=c)
rho = cbloq.tensor_contract(superoperator=True)
should_be = np.asarray([[0.5, 0], [0, 0.5]])
np.testing.assert_allclose(rho, should_be, atol=1e-8)

# One -> Fully mixed state
bb = BloqBuilder()
q = bb.add(OneState())
c = bb.add(MeasX(), q=q)
cbloq = bb.finalize(c=c)
rho = cbloq.tensor_contract(superoperator=True)
should_be = np.asarray([[0.5, 0], [0, 0.5]])
np.testing.assert_allclose(rho, should_be, atol=1e-8)

# Plus measurement -> deterministic zero
bb = BloqBuilder()
q = bb.add(PlusState())
c = bb.add(MeasX(), q=q)
cbloq = bb.finalize(c=c)
rho = cbloq.tensor_contract(superoperator=True)
should_be = np.asarray([[1, 0], [0, 0]])
np.testing.assert_allclose(rho, should_be, atol=1e-8)