Skip to content

Commit 571a283

Browse files
authored
[doc] Update comments for pattern matching (#196)
1 parent 303ad6a commit 571a283

File tree

4 files changed

+56
-37
lines changed

4 files changed

+56
-37
lines changed

src/quartz/tasograph/substitution.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -659,10 +659,10 @@ bool GraphXfer::map_output(const TensorX &src, const TensorX &dst) {
659659
bool GraphXfer::can_match(OpX *srcOp, Op op, const Graph *graph) const {
660660
// This function takes in an OpX, and will check all its input and
661661
// output tensors. If there are tensors connecting it with other already
662-
// mapped ops, check whether these gates exists in the given Graph. No
662+
// mapped ops, check whether these gates exist in the given Graph. No
663663
// need to call this function with topological order. Because once both
664664
// the src op and the dst op are mapped, the edge connecting them will
665-
// be checked. This gauarentee that every gates are checked at the end.
665+
// be checked. This guarantees that all gates are checked at the end.
666666

667667
// Check gate type
668668
if (op == Op::INVALID_OP)

src/quartz/tasograph/substitution.h

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
#include "../context/rule_parser.h"
55
#include "../gate/gate_utils.h"
66
#include "../parser/qasm_parser.h"
7-
#include "assert.h"
87
#include "quartz/circuitseq/circuitseq.h"
98
#include "tasograph.h"
109

10+
#include <cassert>
1111
#include <ostream>
1212
#include <queue>
1313

@@ -16,9 +16,11 @@ namespace quartz {
1616
class OpX;
1717
class GraphXfer;
1818

19+
/**
20+
* A TensorX represents an input/output edge.
21+
*/
1922
struct TensorX {
20-
// A TensorX represnet an output edge
21-
TensorX(void) : op(NULL), idx(0) {}
23+
TensorX() : op(NULL), idx(0) {}
2224
TensorX(OpX *_op, int _idx) : op(_op), idx(_idx) {}
2325
Tensor to_edge(const GraphXfer *xfer) const;
2426
OpX *op; // The op that outputs this tensor
@@ -48,6 +50,9 @@ class TensorXHash {
4850
}
4951
};
5052

53+
/**
54+
* An OpX represents a node in the graph.
55+
*/
5156
class OpX {
5257
public:
5358
OpX(const OpX &_op);
@@ -66,7 +71,7 @@ class GraphCompare {
6671
GraphCompare() {
6772
cost_function_ = [](Graph *graph) { return graph->total_cost(); };
6873
}
69-
GraphCompare(const std::function<float(Graph *)> &cost_function)
74+
explicit GraphCompare(const std::function<float(Graph *)> &cost_function)
7075
: cost_function_(cost_function) {}
7176
bool operator()(const std::shared_ptr<Graph> &lhs,
7277
const std::shared_ptr<Graph> &rhs) {
@@ -83,7 +88,7 @@ class GraphXfer {
8388
GraphXfer(Context *src_ctx, Context *dst_ctx, Context *union_ctx,
8489
const CircuitSeq *src_graph, const CircuitSeq *dst_graph);
8590
bool src_graph_connected(CircuitSeq *src_graph);
86-
TensorX new_tensor(void);
91+
TensorX new_tensor();
8792
bool is_input_qubit(const OpX *opx, int idx) const;
8893
bool is_input_parameter(const OpX *opx, int idx) const;
8994
bool is_symbolic_input_parameter(const OpX *opx, int idx) const;
@@ -101,9 +106,9 @@ class GraphXfer {
101106
bool create_new_operator(const OpX *opx, Op &op);
102107
int num_src_op();
103108
int num_dst_op();
104-
std::string to_str(std::vector<OpX *> const &v) const;
105-
std::string src_str() const;
106-
std::string dst_str() const;
109+
[[nodiscard]] std::string to_str(std::vector<OpX *> const &v) const;
110+
[[nodiscard]] std::string src_str() const;
111+
[[nodiscard]] std::string dst_str() const;
107112
// TODO: not implemented
108113
// std::string to_qasm(std::vector<OpX *> const &v) const;
109114

src/quartz/tasograph/tasograph.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ Op Graph::add_qubit(int qubit_idx) {
287287
return op;
288288
}
289289

290-
Op Graph::add_parameter(const ParamType p) {
290+
Op Graph::add_parameter(const ParamType &p) {
291291
Gate *gate = context->get_gate(GateType::input_param);
292292
auto guid = context->next_global_unique_id();
293293
Op op(guid, gate);
@@ -1537,7 +1537,7 @@ std::shared_ptr<Graph> Graph::from_qasm_file(Context *ctx,
15371537
}
15381538

15391539
std::shared_ptr<Graph> Graph::from_qasm_str(Context *ctx,
1540-
const std::string qasm_str) {
1540+
const std::string &qasm_str) {
15411541
std::stringstream sstream(qasm_str);
15421542
return _from_qasm_stream(ctx, sstream);
15431543
}
@@ -2396,7 +2396,7 @@ bool Graph::_pattern_matching(
23962396
}
23972397
}
23982398
if (!fail) {
2399-
// Check qubit consistancy
2399+
// Check qubit consistency
24002400
std::set<int> qubits;
24012401
for (auto it = xfer->mappedInputs.cbegin(); it != xfer->mappedInputs.cend();
24022402
++it) {

src/quartz/tasograph/tasograph.h

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ bool equal_to_2k_pi(double d);
2525

2626
class Op {
2727
public:
28-
Op(void);
28+
Op();
2929
Op(size_t _guid, Gate *_ptr) : guid(_guid), ptr(_ptr) {}
3030
inline bool operator==(const Op &b) const {
3131
if (guid != b.guid)
@@ -146,13 +146,13 @@ class PosCompare {
146146

147147
class Tensor {
148148
public:
149-
Tensor(void);
149+
Tensor();
150150
int idx;
151151
Op op;
152152
};
153153

154154
struct Edge {
155-
Edge(void);
155+
Edge();
156156
Edge(const Op &_srcOp, const Op &_dstOp, int _srcIdx, int _dstIdx);
157157
Op srcOp, dstOp;
158158
int srcIdx, dstIdx;
@@ -177,21 +177,22 @@ class OpX;
177177

178178
class Graph {
179179
public:
180-
Graph(Context *ctx);
180+
explicit Graph(Context *ctx);
181181
Graph(Context *ctx, const CircuitSeq *seq);
182182
Graph(const Graph &graph);
183183
[[nodiscard]] std::unique_ptr<CircuitSeq> to_circuit_sequence() const;
184184
void _construct_pos_2_logical_qubit();
185185
void add_edge(const Op &srcOp, const Op &dstOp, int srcIdx, int dstIdx);
186-
bool has_edge(const Op &srcOp, const Op &dstOp, int srcIdx, int dstIdx) const;
186+
[[nodiscard]] bool has_edge(const Op &srcOp, const Op &dstOp, int srcIdx,
187+
int dstIdx) const;
187188
Op add_qubit(int qubit_idx);
188-
Op add_parameter(const ParamType p);
189+
Op add_parameter(const ParamType &p);
189190
Op new_gate(GateType gt);
190-
bool has_loop() const;
191+
[[nodiscard]] bool has_loop() const;
191192
size_t hash();
192-
bool equal(const Graph &other) const;
193+
[[nodiscard]] bool equal(const Graph &other) const;
193194
bool check_correctness();
194-
int specific_gate_count(GateType gate_type) const;
195+
[[nodiscard]] int specific_gate_count(GateType gate_type) const;
195196
[[nodiscard]] float total_cost() const;
196197
[[nodiscard]] int gate_count() const;
197198
[[nodiscard]] int circuit_depth() const;
@@ -273,7 +274,8 @@ class Graph {
273274
bool continue_storing_all_steps = false);
274275
void constant_and_rotation_elimination();
275276
void rotation_merging(GateType target_rotation);
276-
std::string to_qasm(bool print_result = false, bool print_guid = false) const;
277+
[[nodiscard]] std::string to_qasm(bool print_result = false,
278+
bool print_guid = false) const;
277279
void to_qasm(const std::string &save_filename, bool print_result,
278280
bool print_guid) const;
279281
template <class _CharT, class _Traits>
@@ -283,10 +285,10 @@ class Graph {
283285
static std::shared_ptr<Graph> from_qasm_file(Context *ctx,
284286
const std::string &filename);
285287
static std::shared_ptr<Graph> from_qasm_str(Context *ctx,
286-
const std::string qasm_str);
288+
const std::string &qasm_str);
287289
void draw_circuit(const std::string &qasm_str,
288290
const std::string &save_filename);
289-
size_t get_num_qubits() const;
291+
[[nodiscard]] size_t get_num_qubits() const;
290292
void print_qubit_ops();
291293
std::shared_ptr<Graph> toffoli_flip_greedy(GateType target_rotation,
292294
GraphXfer *xfer,
@@ -298,9 +300,9 @@ class Graph {
298300
toffoli_flip_by_instruction(GateType target_rotation, GraphXfer *xfer,
299301
GraphXfer *inverse_xfer,
300302
std::vector<int> instruction);
301-
std::vector<size_t> appliable_xfers(Op op,
302-
const std::vector<GraphXfer *> &) const;
303-
std::vector<size_t>
303+
[[nodiscard]] std::vector<size_t>
304+
appliable_xfers(Op op, const std::vector<GraphXfer *> &) const;
305+
[[nodiscard]] std::vector<size_t>
304306
appliable_xfers_parallel(Op op, const std::vector<GraphXfer *> &) const;
305307
bool xfer_appliable(GraphXfer *xfer, Op op) const;
306308
std::shared_ptr<Graph> apply_xfer(GraphXfer *xfer, Op op,
@@ -318,16 +320,16 @@ class Graph {
318320
std::shared_ptr<Graph> ccz_flip_greedy_rz();
319321
std::shared_ptr<Graph> ccz_flip_greedy_u1();
320322
bool _loop_check_after_matching(GraphXfer *xfer) const;
321-
std::shared_ptr<Graph>
323+
[[nodiscard]] std::shared_ptr<Graph>
322324
subgraph(const std::unordered_set<Op, OpHash> &ops) const;
323-
std::vector<std::shared_ptr<Graph>>
325+
[[nodiscard]] std::vector<std::shared_ptr<Graph>>
324326
topology_partition(const int partition_gate_count) const;
325327
/**
326328
* Return the parameter value if the Op is a constant parameter,
327329
* or return 0 otherwise.
328330
*/
329-
ParamType get_param_value(const Op &op) const;
330-
bool param_has_value(const Op &op) const;
331+
[[nodiscard]] ParamType get_param_value(const Op &op) const;
332+
[[nodiscard]] bool param_has_value(const Op &op) const;
331333

332334
private:
333335
void replace_node(Op oldOp, Op newOp);
@@ -344,11 +346,23 @@ class Graph {
344346
bool moveable(GateType tp);
345347
bool move_forward(Pos &pos, bool left);
346348
bool merge_2_rotation_op(Op op_0, Op op_1);
347-
// The common core part of the API xfer_appliable, apply_xfer, and
348-
// apply_xfer_and_track_node. Matches the src dag of xfer to the local dag
349-
// in the circuit whose topological-order root is op. If failed, it
350-
// automatically unmaps the matched nodes. Otherwise, the caller should
351-
// unmap the matched nodes after their work is done.
349+
/**
350+
* The common core part of the API xfer_applicable, apply_xfer, and
351+
* apply_xfer_and_track_node. Matches the src dag of xfer to the local dag
352+
* in the circuit whose topological-order root is op. If failed, it
353+
* automatically unmaps the matched nodes. Otherwise, the caller should
354+
* unmap the matched nodes after their work is done.
355+
* Because the src dag is connected, this function traverses the src dag
356+
* and try to match the local dag in a connected way (not necessarily in
357+
* topological order). For example, the pattern H(q0) H(q1) CX(q0, q1)
358+
* may be matched in the order
359+
* H(q0) (match qubit 0) -> CX(q0, q1) (match qubit 1) <- H(q1).
360+
* @param xfer The xfer with a src dag and a target dag.
361+
* @param op The node to begin with in the local dag.
362+
* @param matched_opx_op_pairs_dq A deque to store matched nodes from
363+
* the src dag to the local dag.
364+
* @return If the match is successful.
365+
*/
352366
bool _pattern_matching(
353367
GraphXfer *xfer, Op op,
354368
std::deque<std::pair<OpX *, Op>> &matched_opx_op_pairs_dq) const;

0 commit comments

Comments
 (0)