Skip to content

Commit 0934472

Browse files
authored
[Verifier] Verify each circuit transformation step (#176)
* [Verifier] Verify each circuit transformation step * [Generator] Switch back to the previous canonical representation version (#175) * [Verifier] Verify all transformation steps * code format
1 parent 367ca9e commit 0934472

11 files changed

+470
-93
lines changed

src/quartz/circuitseq/circuitgate.cpp

+75
Original file line numberDiff line numberDiff line change
@@ -307,4 +307,79 @@ std::string CircuitGate::to_qasm_style_string(Context *ctx,
307307
return result;
308308
}
309309

310+
bool CircuitGate::equivalent(
311+
const CircuitGate *this_gate, const CircuitGate *other_gate,
312+
std::unordered_map<CircuitWire *, CircuitWire *> &wires_mapping,
313+
bool update_mapping, std::queue<CircuitWire *> *wires_to_search,
314+
bool backward) {
315+
if (this_gate->gate->tp != other_gate->gate->tp) {
316+
return false;
317+
}
318+
if (this_gate->input_wires.size() != other_gate->input_wires.size() ||
319+
this_gate->output_wires.size() != other_gate->output_wires.size()) {
320+
return false;
321+
}
322+
if (backward) {
323+
for (int j = 0; j < (int)this_gate->output_wires.size(); j++) {
324+
assert(this_gate->output_wires[j]->is_qubit());
325+
// The output wire must have been mapped.
326+
assert(wires_mapping.count(this_gate->output_wires[j]) != 0);
327+
if (wires_mapping[this_gate->output_wires[j]] !=
328+
other_gate->output_wires[j]) {
329+
return false;
330+
}
331+
}
332+
if (update_mapping) {
333+
// Map input wires
334+
for (int j = 0; j < (int)this_gate->input_wires.size(); j++) {
335+
assert(wires_mapping.count(this_gate->input_wires[j]) == 0);
336+
wires_mapping[this_gate->input_wires[j]] = other_gate->input_wires[j];
337+
wires_to_search->push(this_gate->input_wires[j]);
338+
}
339+
} else {
340+
// Verify mapping
341+
for (int j = 0; j < (int)this_gate->input_wires.size(); j++) {
342+
if (wires_mapping[this_gate->input_wires[j]] !=
343+
other_gate->input_wires[j]) {
344+
return false;
345+
}
346+
}
347+
}
348+
} else {
349+
for (int j = 0; j < (int)this_gate->input_wires.size(); j++) {
350+
if (this_gate->input_wires[j]->is_qubit()) {
351+
// The input wire must have been mapped.
352+
assert(wires_mapping.count(this_gate->input_wires[j]) != 0);
353+
if (wires_mapping[this_gate->input_wires[j]] !=
354+
other_gate->input_wires[j]) {
355+
return false;
356+
}
357+
} else {
358+
// parameters should not be mapped
359+
if (other_gate->input_wires[j] != this_gate->input_wires[j]) {
360+
return false;
361+
}
362+
}
363+
}
364+
if (update_mapping) {
365+
// Map output wires
366+
for (int j = 0; j < (int)this_gate->output_wires.size(); j++) {
367+
assert(wires_mapping.count(this_gate->output_wires[j]) == 0);
368+
wires_mapping[this_gate->output_wires[j]] = other_gate->output_wires[j];
369+
wires_to_search->push(this_gate->output_wires[j]);
370+
}
371+
} else {
372+
// Verify mapping
373+
for (int j = 0; j < (int)this_gate->output_wires.size(); j++) {
374+
if (wires_mapping[this_gate->output_wires[j]] !=
375+
other_gate->output_wires[j]) {
376+
return false;
377+
}
378+
}
379+
}
380+
}
381+
// Equivalent
382+
return true;
383+
}
384+
310385
} // namespace quartz

src/quartz/circuitseq/circuitgate.h

+29
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#include "../utils/utils.h"
55

66
#include <istream>
7+
#include <queue>
8+
#include <unordered_map>
79
#include <vector>
810

911
namespace quartz {
@@ -48,6 +50,33 @@ class CircuitGate {
4850
[[nodiscard]] std::string to_qasm_style_string(Context *ctx,
4951
int param_precision) const;
5052

53+
/**
54+
* Check if two gates in two circuits are equivalent given a mapping of
55+
* circuit wires.
56+
* @param this_gate A gate in the first circuit.
57+
* @param other_gate A gate in the second circuit.
58+
* @param wires_mapping A mapping from the wires in the first circuit to the
59+
* wires in the second circuit.
60+
* @param update_mapping If true, map the output wires, and push the output
61+
* wires in the first circuit to the queue;
62+
* if false, also check if the output wires are already correctly mapped
63+
* (only return true if the gates are equivalent and the output wires are
64+
* correctly mapped).
65+
* @param wires_to_search When |update_mapping| is true and the two gates
66+
* are equivalent under the wires mapping, store the new wires to search in
67+
* the topological sort procedure. Otherwise, this parameter has no effect.
68+
* When |update_mapping| is false, this parameter can be nullptr.
69+
* @param backward If true, the topological sort is performed from the end
70+
* of the circuit to the beginning. Only the output wires instead of the
71+
* input wires are compared when |update_mapping| is true.
72+
* @return True iff the two gates are equivalent under the wires mapping.
73+
*/
74+
static bool
75+
equivalent(const CircuitGate *this_gate, const CircuitGate *other_gate,
76+
std::unordered_map<CircuitWire *, CircuitWire *> &wires_mapping,
77+
bool update_mapping, std::queue<CircuitWire *> *wires_to_search,
78+
bool backward = false);
79+
5180
std::vector<CircuitWire *> input_wires; // Include parameters!
5281
std::vector<CircuitWire *> output_wires;
5382

src/quartz/circuitseq/circuitseq.cpp

+56-60
Original file line numberDiff line numberDiff line change
@@ -53,35 +53,14 @@ bool CircuitSeq::fully_equivalent(const CircuitSeq &other) const {
5353
}
5454
std::unordered_map<CircuitWire *, CircuitWire *> wires_mapping;
5555
for (int i = 0; i < (int)wires.size(); i++) {
56-
wires_mapping[other.wires[i].get()] = wires[i].get();
56+
wires_mapping[wires[i].get()] = other.wires[i].get();
5757
}
5858
for (int i = 0; i < (int)gates.size(); i++) {
59-
if (gates[i]->gate->tp != other.gates[i]->gate->tp) {
60-
return false;
61-
}
62-
if (gates[i]->input_wires.size() != other.gates[i]->input_wires.size() ||
63-
gates[i]->output_wires.size() != other.gates[i]->output_wires.size()) {
59+
if (!CircuitGate::equivalent(gates[i].get(), other.gates[i].get(),
60+
wires_mapping,
61+
/*update_mapping=*/false, nullptr)) {
6462
return false;
6563
}
66-
for (int j = 0; j < (int)gates[i]->input_wires.size(); j++) {
67-
if (other.gates[i]->input_wires[j]->is_qubit()) {
68-
if (wires_mapping[other.gates[i]->input_wires[j]] !=
69-
gates[i]->input_wires[j]) {
70-
return false;
71-
}
72-
} else {
73-
// parameters should not be mapped
74-
if (other.gates[i]->input_wires[j] != gates[i]->input_wires[j]) {
75-
return false;
76-
}
77-
}
78-
}
79-
for (int j = 0; j < (int)gates[i]->output_wires.size(); j++) {
80-
if (wires_mapping[other.gates[i]->output_wires[j]] !=
81-
gates[i]->output_wires[j]) {
82-
return false;
83-
}
84-
}
8564
}
8665
return true;
8766
}
@@ -123,35 +102,11 @@ bool CircuitSeq::topologically_equivalent(const CircuitSeq &other) const {
123102
if (!--gate_remaining_in_degree[this_gate]) {
124103
// Check if this gate is the same as the other gate
125104
auto other_gate = other_wire->output_gates[i];
126-
if (this_gate->gate->tp != other_gate->gate->tp) {
105+
if (!CircuitGate::equivalent(this_gate, other_gate, wires_mapping,
106+
/*update_mapping=*/true,
107+
&wires_to_search)) {
127108
return false;
128109
}
129-
if (this_gate->input_wires.size() != other_gate->input_wires.size() ||
130-
this_gate->output_wires.size() != other_gate->output_wires.size()) {
131-
return false;
132-
}
133-
for (int j = 0; j < (int)this_gate->input_wires.size(); j++) {
134-
if (this_gate->input_wires[j]->is_qubit()) {
135-
// The input wire must have been mapped.
136-
assert(wires_mapping.count(this_gate->input_wires[j]) != 0);
137-
if (wires_mapping[this_gate->input_wires[j]] !=
138-
other_gate->input_wires[j]) {
139-
return false;
140-
}
141-
} else {
142-
// parameters should not be mapped
143-
if (other_gate->input_wires[j] != this_gate->input_wires[j]) {
144-
return false;
145-
}
146-
}
147-
}
148-
// Map output wires
149-
for (int j = 0; j < (int)this_gate->output_wires.size(); j++) {
150-
assert(wires_mapping.count(this_gate->output_wires[j]) == 0);
151-
wires_mapping[this_gate->output_wires[j]] =
152-
other_gate->output_wires[j];
153-
wires_to_search.push(this_gate->output_wires[j]);
154-
}
155110
}
156111
}
157112
}
@@ -484,6 +439,16 @@ bool CircuitSeq::remove_gate(CircuitGate *circuit_gate) {
484439
return true;
485440
}
486441

442+
bool CircuitSeq::remove_gate_near_end(CircuitGate *circuit_gate) {
443+
auto gate_pos = std::find_if(
444+
gates.rbegin(), gates.rend(),
445+
[&](std::unique_ptr<CircuitGate> &p) { return p.get() == circuit_gate; });
446+
if (gate_pos == gates.rend()) {
447+
return false;
448+
}
449+
return remove_gate((int)(gates.rend() - gate_pos) - 1);
450+
}
451+
487452
bool CircuitSeq::remove_first_quantum_gate() {
488453
for (auto &circuit_gate : gates) {
489454
if (circuit_gate->gate->is_quantum_gate()) {
@@ -1307,6 +1272,35 @@ CircuitSeq::get_permuted_seq(const std::vector<int> &qubit_permutation,
13071272
return result;
13081273
}
13091274

1275+
std::unique_ptr<CircuitSeq>
1276+
CircuitSeq::get_suffix_seq(const std::unordered_set<CircuitGate *> &start_gates,
1277+
Context *ctx) const {
1278+
// For topological sort
1279+
std::unordered_map<CircuitGate *, int> gate_remaining_in_degree;
1280+
for (auto &gate : start_gates) {
1281+
gate_remaining_in_degree[gate] = 0; // ready to include
1282+
}
1283+
auto result = std::make_unique<CircuitSeq>(get_num_qubits());
1284+
// The result should be a subsequence of this circuit
1285+
for (auto &gate : gates) {
1286+
if (gate_remaining_in_degree.count(gate.get()) > 0 &&
1287+
gate_remaining_in_degree[gate.get()] <= 0) {
1288+
result->add_gate(gate.get(), ctx);
1289+
for (auto &output_wire : gate->output_wires) {
1290+
for (auto &output_gate : output_wire->output_gates) {
1291+
// For topological sort
1292+
if (gate_remaining_in_degree.count(output_gate) == 0) {
1293+
gate_remaining_in_degree[output_gate] =
1294+
output_gate->gate->get_num_qubits();
1295+
}
1296+
gate_remaining_in_degree[output_gate]--;
1297+
}
1298+
}
1299+
}
1300+
}
1301+
return result;
1302+
}
1303+
13101304
void CircuitSeq::clone_from(const CircuitSeq &other,
13111305
const std::vector<int> &qubit_permutation,
13121306
const std::vector<int> &param_permutation,
@@ -1592,20 +1586,22 @@ std::vector<int> CircuitSeq::first_quantum_gate_positions() const {
15921586
return result;
15931587
}
15941588

1589+
bool CircuitSeq::is_one_of_last_gates(CircuitGate *circuit_gate) const {
1590+
for (const auto &output_wire : circuit_gate->output_wires) {
1591+
if (outputs[output_wire->index] != output_wire) {
1592+
return false;
1593+
}
1594+
}
1595+
return true;
1596+
}
1597+
15951598
std::vector<CircuitGate *> CircuitSeq::last_quantum_gates() const {
15961599
std::vector<CircuitGate *> result;
15971600
for (const auto &circuit_gate : gates) {
15981601
if (circuit_gate->gate->is_parameter_gate()) {
15991602
continue;
16001603
}
1601-
bool all_output = true;
1602-
for (const auto &output_wire : circuit_gate->output_wires) {
1603-
if (outputs[output_wire->index] != output_wire) {
1604-
all_output = false;
1605-
break;
1606-
}
1607-
}
1608-
if (all_output) {
1604+
if (is_one_of_last_gates(circuit_gate.get())) {
16091605
result.push_back(circuit_gate.get());
16101606
}
16111607
}

src/quartz/circuitseq/circuitseq.h

+24-2
Original file line numberDiff line numberDiff line change
@@ -106,17 +106,23 @@ class CircuitSeq {
106106
bool remove_last_gate();
107107

108108
/**
109-
* Remove a quantum gate.
109+
* Remove a quantum gate in O(|get_num_gates()| - |gate_position|).
110110
* @param gate_position The position of the gate to be removed (0-indexed).
111111
* @return True iff the removal is successful.
112112
*/
113113
bool remove_gate(int gate_position);
114114
/**
115-
* Remove a quantum gate.
115+
* Remove a quantum gate in O(|get_num_gates()|).
116116
* @param circuit_gate The gate to be removed.
117117
* @return True iff the removal is successful.
118118
*/
119119
bool remove_gate(CircuitGate *circuit_gate);
120+
/**
121+
* Remove a quantum gate in O(|get_num_gates()| - |gate_position|).
122+
* @param circuit_gate The gate to be removed.
123+
* @return True iff the removal is successful.
124+
*/
125+
bool remove_gate_near_end(CircuitGate *circuit_gate);
120126
/**
121127
* Remove the first quantum gate (if there is one).
122128
* @return True iff the removal is successful.
@@ -279,6 +285,14 @@ class CircuitSeq {
279285
get_permuted_seq(const std::vector<int> &qubit_permutation,
280286
const std::vector<int> &input_param_permutation,
281287
Context *ctx) const;
288+
/**
289+
* Get a circuit with |start_gates| and all gates topologically after them.
290+
* @param start_gates The first gates at each qubit to include in the
291+
* circuit to return.
292+
*/
293+
[[nodiscard]] std::unique_ptr<CircuitSeq>
294+
get_suffix_seq(const std::unordered_set<CircuitGate *> &start_gates,
295+
Context *ctx) const;
282296

283297
/**
284298
* Get a circuit which replaces RZ gates with T, Tdg, S, Sdg, and Z gates.
@@ -307,6 +321,14 @@ class CircuitSeq {
307321
* @return The positions (0-indexed) of the first quantum gates.
308322
*/
309323
[[nodiscard]] std::vector<int> first_quantum_gate_positions() const;
324+
/**
325+
* Check if a quantum gate can appear at last in some topological
326+
* order of the CircuitSeq.
327+
* @param circuit_gate The pointer to a quantum gate in the circuit.
328+
* @return True iff the gate can appear at last in some topological
329+
* order of the CircuitSeq.
330+
*/
331+
[[nodiscard]] bool is_one_of_last_gates(CircuitGate *circuit_gate) const;
310332
/**
311333
* Returns quantum gates which can appear at last in some topological
312334
* order of the CircuitSeq.

src/quartz/generator/generator.cpp

+5-9
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,6 @@ bool Generator::generate(
1010
bool invoke_python_verifier, EquivalenceSet *equiv_set,
1111
bool unique_parameters, bool verbose,
1212
std::chrono::steady_clock::duration *record_verification_time) {
13-
std::filesystem::path this_file_path(__FILE__);
14-
auto quartz_root_path =
15-
this_file_path.parent_path().parent_path().parent_path().parent_path();
16-
1713
auto empty_dag = std::make_unique<CircuitSeq>(num_qubits);
1814
empty_dag->hash(ctx_); // generate other hash values
1915
std::vector<CircuitSeq *> dags_to_search(1, empty_dag.get());
@@ -61,7 +57,7 @@ bool Generator::generate(
6157
if (num_gates == max_num_quantum_gates) {
6258
break;
6359
}
64-
bool ret = dataset->save_json(ctx_, quartz_root_path.string() +
60+
bool ret = dataset->save_json(ctx_, kQuartzRootPath.string() +
6561
"/tmp_before_verify.json");
6662
assert(ret);
6763

@@ -70,10 +66,10 @@ bool Generator::generate(
7066
start = std::chrono::steady_clock::now();
7167
}
7268
std::string command_string =
73-
std::string("python ") + quartz_root_path.string() +
69+
std::string("python ") + kQuartzRootPath.string() +
7470
"/src/python/verifier/verify_equivalences.py " +
75-
quartz_root_path.string() + "/tmp_before_verify.json " +
76-
quartz_root_path.string() + "/tmp_after_verify.json";
71+
kQuartzRootPath.string() + "/tmp_before_verify.json " +
72+
kQuartzRootPath.string() + "/tmp_after_verify.json";
7773
system(command_string.c_str());
7874
if (record_verification_time) {
7975
auto end = std::chrono::steady_clock::now();
@@ -82,7 +78,7 @@ bool Generator::generate(
8278

8379
dags_to_search.clear();
8480
ret = equiv_set->load_json(
85-
ctx_, quartz_root_path.string() + "/tmp_after_verify.json",
81+
ctx_, kQuartzRootPath.string() + "/tmp_after_verify.json",
8682
/*from_verifier=*/true, &dags_to_search);
8783
assert(ret);
8884
for (auto &dag : dags_to_search) {

0 commit comments

Comments
 (0)