Skip to content

Commit 3014047

Browse files
authored
[Generator] Use row representation in RepGen (#173)
* [Generator] Use row representation to compare circuits * Add sort() to Dataset * Compare gate count first * common subcircuit pruning in repgen and topological equivalence * Generate Nam ECC Sets * Make gen_ecc_set able to run from any directory * Fix tmp json location * Use row representation to compare for the case not invoking Python verifier
1 parent 8923578 commit 3014047

File tree

9 files changed

+429
-93
lines changed

9 files changed

+429
-93
lines changed

src/quartz/circuitseq/circuitseq.cpp

+205-23
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <cassert>
1010
#include <charconv>
1111
#include <fstream>
12+
#include <optional>
1213
#include <queue>
1314
#include <unordered_set>
1415
#include <utility>
@@ -85,6 +86,78 @@ bool CircuitSeq::fully_equivalent(const CircuitSeq &other) const {
8586
return true;
8687
}
8788

89+
bool CircuitSeq::topologically_equivalent(const CircuitSeq &other) const {
90+
if (this == &other) {
91+
return true;
92+
}
93+
if (num_qubits != other.num_qubits) {
94+
return false;
95+
}
96+
if (wires.size() != other.wires.size() ||
97+
gates.size() != other.gates.size()) {
98+
return false;
99+
}
100+
// Mapping from this circuit to the other circuit
101+
std::unordered_map<CircuitWire *, CircuitWire *> wires_mapping;
102+
std::queue<CircuitWire *> wires_to_search;
103+
std::unordered_map<CircuitGate *, int> gate_remaining_in_degree;
104+
for (int i = 0; i < num_qubits; i++) {
105+
wires_mapping[wires[i].get()] = other.wires[i].get();
106+
wires_to_search.push(wires[i].get());
107+
}
108+
// Topological sort on this circuit
109+
while (!wires_to_search.empty()) {
110+
auto this_wire = wires_to_search.front();
111+
auto other_wire = wires_mapping[this_wire];
112+
assert(other_wire);
113+
wires_to_search.pop();
114+
if (this_wire->output_gates.size() != other_wire->output_gates.size()) {
115+
return false;
116+
}
117+
for (int i = 0; i < (int)this_wire->output_gates.size(); i++) {
118+
auto this_gate = this_wire->output_gates[i];
119+
if (gate_remaining_in_degree.count(this_gate) == 0) {
120+
// A new gate
121+
gate_remaining_in_degree[this_gate] = this_gate->gate->get_num_qubits();
122+
}
123+
if (!--gate_remaining_in_degree[this_gate]) {
124+
// Check if this gate is the same as the other gate
125+
auto other_gate = other_wire->output_gates[i];
126+
if (this_gate->gate->tp != other_gate->gate->tp) {
127+
return false;
128+
}
129+
if (this_gate->input_wires.size() != other_gate->input_wires.size() ||
130+
this_gate->output_wires.size() != other_gate->output_wires.size()) {
131+
return false;
132+
}
133+
for (int j = 0; j < (int)this_gate->input_wires.size(); j++) {
134+
if (this_gate->input_wires[j]->is_qubit()) {
135+
// The input wire must have been mapped.
136+
assert(wires_mapping.count(this_gate->input_wires[j]) != 0);
137+
if (wires_mapping[this_gate->input_wires[j]] !=
138+
other_gate->input_wires[j]) {
139+
return false;
140+
}
141+
} else {
142+
// parameters should not be mapped
143+
if (other_gate->input_wires[j] != this_gate->input_wires[j]) {
144+
return false;
145+
}
146+
}
147+
}
148+
// Map output wires
149+
for (int j = 0; j < (int)this_gate->output_wires.size(); j++) {
150+
assert(wires_mapping.count(this_gate->output_wires[j]) == 0);
151+
wires_mapping[this_gate->output_wires[j]] =
152+
other_gate->output_wires[j];
153+
wires_to_search.push(this_gate->output_wires[j]);
154+
}
155+
}
156+
}
157+
}
158+
return true; // equivalent
159+
}
160+
88161
bool CircuitSeq::fully_equivalent(Context *ctx, CircuitSeq &other) {
89162
if (hash(ctx) != other.hash(ctx)) {
90163
return false;
@@ -102,33 +175,106 @@ bool CircuitSeq::less_than(const CircuitSeq &other) const {
102175
if (get_num_gates() != other.get_num_gates()) {
103176
return get_num_gates() < other.get_num_gates();
104177
}
105-
for (int i = 0; i < (int)gates.size(); i++) {
106-
if (gates[i]->gate->tp != other.gates[i]->gate->tp) {
107-
return gates[i]->gate->tp < other.gates[i]->gate->tp;
108-
}
109-
assert(gates[i]->input_wires.size() == other.gates[i]->input_wires.size());
110-
assert(gates[i]->output_wires.size() ==
111-
other.gates[i]->output_wires.size());
112-
for (int j = 0; j < (int)gates[i]->input_wires.size(); j++) {
113-
if (gates[i]->input_wires[j]->is_qubit() !=
114-
other.gates[i]->input_wires[j]->is_qubit()) {
115-
return gates[i]->input_wires[j]->is_qubit();
178+
if (kUseRowRepresentationToCompare) {
179+
for (int i = 0; i < num_qubits; i++) {
180+
// Compare all gates on qubit i.
181+
auto this_ptr = wires[i].get();
182+
auto other_ptr = other.wires[i].get();
183+
std::optional<bool> compare_outcome = std::nullopt;
184+
while (this_ptr != outputs[i]) {
185+
if (other_ptr == other.outputs[i]) {
186+
// This circuit sequence has more gates on qubit i,
187+
// so this circuit is greater.
188+
return false;
189+
}
190+
assert(this_ptr->output_gates->size() == 1);
191+
assert(other_ptr->output_gates->size() == 1);
192+
auto this_gate = this_ptr->output_gates[0];
193+
auto other_gate = other_ptr->output_gates[0];
194+
if (!compare_outcome.has_value()) {
195+
if (this_gate->gate->tp != other_gate->gate->tp) {
196+
compare_outcome = this_gate->gate->tp < other_gate->gate->tp;
197+
} else {
198+
assert(this_gate->input_wires.size() ==
199+
other_gate->input_wires.size());
200+
assert(this_gate->output_wires.size() ==
201+
other_gate->output_wires.size());
202+
for (int j = 0; j < (int)this_gate->input_wires.size(); j++) {
203+
if (this_gate->input_wires[j]->is_qubit() !=
204+
other_gate->input_wires[j]->is_qubit()) {
205+
compare_outcome = this_gate->input_wires[j]->is_qubit();
206+
break;
207+
}
208+
if (this_gate->input_wires[j]->index !=
209+
other_gate->input_wires[j]->index) {
210+
compare_outcome = this_gate->input_wires[j]->index <
211+
other_gate->input_wires[j]->index;
212+
break;
213+
}
214+
}
215+
}
216+
}
217+
// No need to compare output wires for quantum gates.
218+
bool found_output_wire = false;
219+
for (auto &output_wire : this_gate->output_wires) {
220+
if (output_wire->index == i) {
221+
found_output_wire = true;
222+
this_ptr = output_wire;
223+
break;
224+
}
225+
}
226+
assert(found_output_wire);
227+
found_output_wire = false;
228+
for (auto &output_wire : other_gate->output_wires) {
229+
if (output_wire->index == i) {
230+
found_output_wire = true;
231+
other_ptr = output_wire;
232+
break;
233+
}
234+
}
235+
assert(found_output_wire);
236+
}
237+
if (other_ptr != other.outputs[i]) {
238+
// The other circuit sequence has more gates on qubit i,
239+
// so this circuit is less.
240+
return true;
116241
}
117-
if (gates[i]->input_wires[j]->index !=
118-
other.gates[i]->input_wires[j]->index) {
119-
return gates[i]->input_wires[j]->index <
120-
other.gates[i]->input_wires[j]->index;
242+
// Two circuit sequences have the same number of gates on qubit i.
243+
// Compare the contents.
244+
if (compare_outcome.has_value()) {
245+
return compare_outcome.value();
121246
}
122247
}
123-
for (int j = 0; j < (int)gates[i]->output_wires.size(); j++) {
124-
if (gates[i]->output_wires[j]->is_qubit() !=
125-
other.gates[i]->output_wires[j]->is_qubit()) {
126-
return gates[i]->output_wires[j]->is_qubit();
248+
} else {
249+
for (int i = 0; i < (int)gates.size(); i++) {
250+
if (gates[i]->gate->tp != other.gates[i]->gate->tp) {
251+
return gates[i]->gate->tp < other.gates[i]->gate->tp;
252+
}
253+
assert(gates[i]->input_wires.size() ==
254+
other.gates[i]->input_wires.size());
255+
assert(gates[i]->output_wires.size() ==
256+
other.gates[i]->output_wires.size());
257+
for (int j = 0; j < (int)gates[i]->input_wires.size(); j++) {
258+
if (gates[i]->input_wires[j]->is_qubit() !=
259+
other.gates[i]->input_wires[j]->is_qubit()) {
260+
return gates[i]->input_wires[j]->is_qubit();
261+
}
262+
if (gates[i]->input_wires[j]->index !=
263+
other.gates[i]->input_wires[j]->index) {
264+
return gates[i]->input_wires[j]->index <
265+
other.gates[i]->input_wires[j]->index;
266+
}
127267
}
128-
if (gates[i]->output_wires[j]->index !=
129-
other.gates[i]->output_wires[j]->index) {
130-
return gates[i]->output_wires[j]->index <
131-
other.gates[i]->output_wires[j]->index;
268+
for (int j = 0; j < (int)gates[i]->output_wires.size(); j++) {
269+
if (gates[i]->output_wires[j]->is_qubit() !=
270+
other.gates[i]->output_wires[j]->is_qubit()) {
271+
return gates[i]->output_wires[j]->is_qubit();
272+
}
273+
if (gates[i]->output_wires[j]->index !=
274+
other.gates[i]->output_wires[j]->index) {
275+
return gates[i]->output_wires[j]->index <
276+
other.gates[i]->output_wires[j]->index;
277+
}
132278
}
133279
}
134280
}
@@ -308,6 +454,20 @@ bool CircuitSeq::remove_last_gate() {
308454
return true;
309455
}
310456

457+
bool CircuitSeq::remove_gate(int gate_position) {
458+
if (gate_position < 0 || gate_position >= (int)gates.size()) {
459+
return false;
460+
}
461+
CircuitGate *circuit_gate = gates[gate_position].get();
462+
auto *gate = circuit_gate->gate;
463+
assert(gate->is_quantum_gate());
464+
remove_quantum_gate_from_graph(circuit_gate);
465+
// Remove the gate.
466+
gates.erase(gates.begin() + gate_position);
467+
hash_value_valid_ = false;
468+
return true;
469+
}
470+
311471
bool CircuitSeq::remove_gate(CircuitGate *circuit_gate) {
312472
auto gate_pos = std::find_if(
313473
gates.begin(), gates.end(),
@@ -1410,6 +1570,28 @@ std::vector<CircuitGate *> CircuitSeq::first_quantum_gates() const {
14101570
return result;
14111571
}
14121572

1573+
std::vector<int> CircuitSeq::first_quantum_gate_positions() const {
1574+
std::vector<int> result;
1575+
std::unordered_set<CircuitGate *> depend_on_other_gates;
1576+
depend_on_other_gates.reserve(gates.size());
1577+
for (int i = 0; i < (int)gates.size(); i++) {
1578+
CircuitGate *circuit_gate = gates[i].get();
1579+
if (circuit_gate->gate->is_parameter_gate()) {
1580+
continue;
1581+
}
1582+
if (depend_on_other_gates.find(circuit_gate) ==
1583+
depend_on_other_gates.end()) {
1584+
result.push_back(i);
1585+
}
1586+
for (const auto &output_wire : circuit_gate->output_wires) {
1587+
for (const auto &output_gate : output_wire->output_gates) {
1588+
depend_on_other_gates.insert(output_gate);
1589+
}
1590+
}
1591+
}
1592+
return result;
1593+
}
1594+
14131595
std::vector<CircuitGate *> CircuitSeq::last_quantum_gates() const {
14141596
std::vector<CircuitGate *> result;
14151597
for (const auto &circuit_gate : gates) {

src/quartz/circuitseq/circuitseq.h

+40-5
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@ class CircuitSeq {
2828
* @return True iff two circuit sequences are fully equivalent.
2929
*/
3030
[[nodiscard]] bool fully_equivalent(const CircuitSeq &other) const;
31+
/**
32+
* Compare if two circuits are topologically equivalent.
33+
* X(Q0) X(Q1) and X(Q1) X(Q0) are topologically equivalent but not
34+
* fully equivalent.
35+
* @param other The other circuit sequence to be compared.
36+
* @return True iff two circuit sequences are topologically equivalent.
37+
*/
38+
[[nodiscard]] bool topologically_equivalent(const CircuitSeq &other) const;
3139
/**
3240
* Compute the hash value and compare if two circuit sequences are fully
3341
* equivalent including the hash value.
@@ -36,6 +44,15 @@ class CircuitSeq {
3644
* @return True iff two circuit sequences are fully equivalent.
3745
*/
3846
[[nodiscard]] bool fully_equivalent(Context *ctx, CircuitSeq &other);
47+
/**
48+
* Compare two circuit sequences first by the qubit count (fewer is less),
49+
* then by the gate count (fewer is less), then by the gate sequence.
50+
* If |kUseRowRepresentationToCompare| is true, compare the gates on qubit 0
51+
* first (fewer is less, then compare by the content), then qubit 1, ...
52+
* @param other The other circuit sequence to compare with.
53+
* @return True iff this circuit sequence is strictly less than the other
54+
* circuit sequence.
55+
*/
3956
[[nodiscard]] bool less_than(const CircuitSeq &other) const;
4057

4158
/**
@@ -90,7 +107,13 @@ class CircuitSeq {
90107

91108
/**
92109
* Remove a quantum gate.
93-
* @param circuit_gate the gate to be removed.
110+
* @param gate_position The position of the gate to be removed (0-indexed).
111+
* @return True iff the removal is successful.
112+
*/
113+
bool remove_gate(int gate_position);
114+
/**
115+
* Remove a quantum gate.
116+
* @param circuit_gate The gate to be removed.
94117
* @return True iff the removal is successful.
95118
*/
96119
bool remove_gate(CircuitGate *circuit_gate);
@@ -272,11 +295,23 @@ class CircuitSeq {
272295
*/
273296
std::unique_ptr<CircuitSeq> get_ccz_to_cx_rz(Context *ctx) const;
274297

275-
// Returns quantum gates which do not topologically depend on any other
276-
// quantum gates.
298+
/**
299+
* Returns quantum gates which do not topologically depend on any other
300+
* quantum gates.
301+
* @return The pointers to the first quantum gates.
302+
*/
277303
[[nodiscard]] std::vector<CircuitGate *> first_quantum_gates() const;
278-
// Returns quantum gates which can appear at last in some topological
279-
// order of the CircuitSeq.
304+
/**
305+
* Returns quantum gates which do not topologically depend on any other
306+
* quantum gates.
307+
* @return The positions (0-indexed) of the first quantum gates.
308+
*/
309+
[[nodiscard]] std::vector<int> first_quantum_gate_positions() const;
310+
/**
311+
* Returns quantum gates which can appear at last in some topological
312+
* order of the CircuitSeq.
313+
* @return The pointers to the last quantum gates.
314+
*/
280315
[[nodiscard]] std::vector<CircuitGate *> last_quantum_gates() const;
281316

282317
static bool same_gate(const CircuitSeq &seq1, int index1,

src/quartz/dataset/dataset.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,13 @@ int Dataset::normalize_to_canonical_representations(Context *ctx) {
161161
return num_removed;
162162
}
163163

164+
void Dataset::sort() {
165+
for (auto &it : dataset) {
166+
std::sort(it.second.begin(), it.second.end(),
167+
UniquePtrCircuitSeqComparator());
168+
}
169+
}
170+
164171
bool Dataset::insert(Context *ctx, std::unique_ptr<CircuitSeq> dag) {
165172
const auto hash_value = dag->hash(ctx);
166173
bool ret = dataset.count(hash_value) == 0;

src/quartz/dataset/dataset.h

+5
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ class Dataset {
2424
// Return the number of DAGs removed.
2525
int normalize_to_canonical_representations(Context *ctx);
2626

27+
/**
28+
* Sort the circuits with the same hash value by CircuitSeq::less_than().
29+
*/
30+
void sort();
31+
2732
// This function runs in O(1).
2833
[[nodiscard]] int num_hash_values() const;
2934

0 commit comments

Comments
 (0)