diff --git a/qermit/taskgraph/mitex.py b/qermit/taskgraph/mitex.py index 4e51423a..f7f09003 100644 --- a/qermit/taskgraph/mitex.py +++ b/qermit/taskgraph/mitex.py @@ -93,8 +93,12 @@ def get_basic_measurement_circuit( :param string: Qubit Pauli String to be measured :type string: QubitPauliString - :return: Measurement circuit for appending on some ansatz - :rtype: Circuit + :return: Tuple of measurement circuit for appending on some ansatz + and MeasurementInfo. Each entry of th measurement information + contains a QubitPauliString, the bits required to take expectation + over in resulting result and a bool signifying whether expectation + should be inverted when taking result. + :rtype: Tuple[Circuit, MeasurementInfo] """ measurement_circuit = Circuit() measured_qbs = [] @@ -152,11 +156,11 @@ def task( circ = circuit.copy() # tuple, first entry is measurement circuit for appending # second entry is MeasurementInfo for deriving expectation - measurement_circuit = get_basic_measurement_circuit(string) - circ.append(measurement_circuit[0]) + measurement_circuit, measurement_info = get_basic_measurement_circuit(string) + circ.append(measurement_circuit) # add new circuit to observable tracker observable_tracker.add_measurement_circuit( - MeasurementCircuit(circ, symbols), [measurement_circuit[1]] + MeasurementCircuit(circ, symbols), [measurement_info] ) # retrieve all measurement circuits, substitute symbols diff --git a/qermit/taskgraph/utils.py b/qermit/taskgraph/utils.py index 8b8697db..ce8d1d68 100644 --- a/qermit/taskgraph/utils.py +++ b/qermit/taskgraph/utils.py @@ -403,10 +403,11 @@ def check_string(self, string: QubitPauliString) -> bool: """ if string not in self._qps_to_indices: return False + if len(self._qps_to_indices[string]) > 0: return True - else: - return False + + return False def get_empty_strings(self) -> List[QubitPauliString]: """ diff --git a/tests/mitex_test.py b/tests/mitex_test.py index 76c8b02b..c58023fc 100644 --- a/tests/mitex_test.py +++ b/tests/mitex_test.py @@ -15,7 +15,7 @@ import copy -from pytket.circuit import Circuit, OpType, Qubit, fresh_symbol # type: ignore +from pytket.circuit import Bit, Circuit, OpType, Qubit, fresh_symbol # type: ignore from pytket.extensions.qiskit import AerBackend # type: ignore from pytket.pauli import Pauli, QubitPauliString # type: ignore from pytket.utils import QubitPauliOperator @@ -32,11 +32,63 @@ collate_circuit_shots_task_gen, filter_observable_tracker_task_gen, gen_compiled_shot_split_MitRes, + get_basic_measurement_circuit, get_expectations_task_gen, split_results_task_gen, ) +def test_get_basic_measurement_circuit() -> None: + + q_0 = Qubit(name="test qubits a", index=0) + q_1 = Qubit(name="test qubits a", index=1) + q_2 = Qubit(name="test qubits b", index=0) + q_3 = Qubit(name="test qubits b", index=1) + + qps_0 = QubitPauliString([q_1, q_2, q_3], [Pauli.Z, Pauli.X, Pauli.I]) + + circuit = Circuit() + + circuit.add_qubit(q_1) + circuit.add_bit(Bit(0)) + circuit.Measure(q_1, Bit(0)) + + circuit.add_qubit(q_2) + circuit.H(q_2) + circuit.add_bit(Bit(1)) + circuit.Measure(q_2, Bit(1)) + + measurement_circuit, (string, bits, meas_bool) = get_basic_measurement_circuit(qps_0) + assert measurement_circuit == circuit + assert bits == [Bit(0), Bit(1)] + assert not meas_bool + assert string == qps_0 + + qps_1 = QubitPauliString([q_0, q_2, q_1], [Pauli.Z, Pauli.Y, Pauli.X]) + + circuit = Circuit() + + circuit.add_qubit(q_0) + circuit.add_bit(Bit(0)) + circuit.Measure(q_0, Bit(0)) + + circuit.add_qubit(q_1) + circuit.add_bit(Bit(1)) + circuit.H(q_1) + circuit.Measure(q_1, Bit(1)) + + circuit.add_qubit(q_2) + circuit.add_bit(Bit(2)) + circuit.Rx(0.5, q_2) + circuit.Measure(q_2, Bit(2)) + + measurement_circuit, (string, bits, meas_bool) = get_basic_measurement_circuit(qps_1) + assert measurement_circuit == circuit + assert bits == [Bit(0), Bit(1), Bit(2)] + assert not meas_bool + assert string == qps_1 + + def test_mitex_cache(): circuit = Circuit(1).X(0) backend = AerBackend()