Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[doc] Update comments for pattern matching #196

Merged
merged 1 commit into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/quartz/tasograph/substitution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -659,10 +659,10 @@ bool GraphXfer::map_output(const TensorX &src, const TensorX &dst) {
bool GraphXfer::can_match(OpX *srcOp, Op op, const Graph *graph) const {
// This function takes in an OpX, and will check all its input and
// output tensors. If there are tensors connecting it with other already
// mapped ops, check whether these gates exists in the given Graph. No
// mapped ops, check whether these gates exist in the given Graph. No
// need to call this function with topological order. Because once both
// the src op and the dst op are mapped, the edge connecting them will
// be checked. This gauarentee that every gates are checked at the end.
// be checked. This guarantees that all gates are checked at the end.

// Check gate type
if (op == Op::INVALID_OP)
Expand Down
21 changes: 13 additions & 8 deletions src/quartz/tasograph/substitution.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
#include "../context/rule_parser.h"
#include "../gate/gate_utils.h"
#include "../parser/qasm_parser.h"
#include "assert.h"
#include "quartz/circuitseq/circuitseq.h"
#include "tasograph.h"

#include <cassert>
#include <ostream>
#include <queue>

Expand All @@ -16,9 +16,11 @@ namespace quartz {
class OpX;
class GraphXfer;

/**
* A TensorX represents an input/output edge.
*/
struct TensorX {
// A TensorX represnet an output edge
TensorX(void) : op(NULL), idx(0) {}
TensorX() : op(NULL), idx(0) {}
TensorX(OpX *_op, int _idx) : op(_op), idx(_idx) {}
Tensor to_edge(const GraphXfer *xfer) const;
OpX *op; // The op that outputs this tensor
Expand Down Expand Up @@ -48,6 +50,9 @@ class TensorXHash {
}
};

/**
* An OpX represents a node in the graph.
*/
class OpX {
public:
OpX(const OpX &_op);
Expand All @@ -66,7 +71,7 @@ class GraphCompare {
GraphCompare() {
cost_function_ = [](Graph *graph) { return graph->total_cost(); };
}
GraphCompare(const std::function<float(Graph *)> &cost_function)
explicit GraphCompare(const std::function<float(Graph *)> &cost_function)
: cost_function_(cost_function) {}
bool operator()(const std::shared_ptr<Graph> &lhs,
const std::shared_ptr<Graph> &rhs) {
Expand All @@ -83,7 +88,7 @@ class GraphXfer {
GraphXfer(Context *src_ctx, Context *dst_ctx, Context *union_ctx,
const CircuitSeq *src_graph, const CircuitSeq *dst_graph);
bool src_graph_connected(CircuitSeq *src_graph);
TensorX new_tensor(void);
TensorX new_tensor();
bool is_input_qubit(const OpX *opx, int idx) const;
bool is_input_parameter(const OpX *opx, int idx) const;
bool is_symbolic_input_parameter(const OpX *opx, int idx) const;
Expand All @@ -101,9 +106,9 @@ class GraphXfer {
bool create_new_operator(const OpX *opx, Op &op);
int num_src_op();
int num_dst_op();
std::string to_str(std::vector<OpX *> const &v) const;
std::string src_str() const;
std::string dst_str() const;
[[nodiscard]] std::string to_str(std::vector<OpX *> const &v) const;
[[nodiscard]] std::string src_str() const;
[[nodiscard]] std::string dst_str() const;
// TODO: not implemented
// std::string to_qasm(std::vector<OpX *> const &v) const;

Expand Down
6 changes: 3 additions & 3 deletions src/quartz/tasograph/tasograph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ Op Graph::add_qubit(int qubit_idx) {
return op;
}

Op Graph::add_parameter(const ParamType p) {
Op Graph::add_parameter(const ParamType &p) {
Gate *gate = context->get_gate(GateType::input_param);
auto guid = context->next_global_unique_id();
Op op(guid, gate);
Expand Down Expand Up @@ -1537,7 +1537,7 @@ std::shared_ptr<Graph> Graph::from_qasm_file(Context *ctx,
}

std::shared_ptr<Graph> Graph::from_qasm_str(Context *ctx,
const std::string qasm_str) {
const std::string &qasm_str) {
std::stringstream sstream(qasm_str);
return _from_qasm_stream(ctx, sstream);
}
Expand Down Expand Up @@ -2396,7 +2396,7 @@ bool Graph::_pattern_matching(
}
}
if (!fail) {
// Check qubit consistancy
// Check qubit consistency
std::set<int> qubits;
for (auto it = xfer->mappedInputs.cbegin(); it != xfer->mappedInputs.cend();
++it) {
Expand Down
62 changes: 38 additions & 24 deletions src/quartz/tasograph/tasograph.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ bool equal_to_2k_pi(double d);

class Op {
public:
Op(void);
Op();
Op(size_t _guid, Gate *_ptr) : guid(_guid), ptr(_ptr) {}
inline bool operator==(const Op &b) const {
if (guid != b.guid)
Expand Down Expand Up @@ -146,13 +146,13 @@ class PosCompare {

class Tensor {
public:
Tensor(void);
Tensor();
int idx;
Op op;
};

struct Edge {
Edge(void);
Edge();
Edge(const Op &_srcOp, const Op &_dstOp, int _srcIdx, int _dstIdx);
Op srcOp, dstOp;
int srcIdx, dstIdx;
Expand All @@ -177,21 +177,22 @@ class OpX;

class Graph {
public:
Graph(Context *ctx);
explicit Graph(Context *ctx);
Graph(Context *ctx, const CircuitSeq *seq);
Graph(const Graph &graph);
[[nodiscard]] std::unique_ptr<CircuitSeq> to_circuit_sequence() const;
void _construct_pos_2_logical_qubit();
void add_edge(const Op &srcOp, const Op &dstOp, int srcIdx, int dstIdx);
bool has_edge(const Op &srcOp, const Op &dstOp, int srcIdx, int dstIdx) const;
[[nodiscard]] bool has_edge(const Op &srcOp, const Op &dstOp, int srcIdx,
int dstIdx) const;
Op add_qubit(int qubit_idx);
Op add_parameter(const ParamType p);
Op add_parameter(const ParamType &p);
Op new_gate(GateType gt);
bool has_loop() const;
[[nodiscard]] bool has_loop() const;
size_t hash();
bool equal(const Graph &other) const;
[[nodiscard]] bool equal(const Graph &other) const;
bool check_correctness();
int specific_gate_count(GateType gate_type) const;
[[nodiscard]] int specific_gate_count(GateType gate_type) const;
[[nodiscard]] float total_cost() const;
[[nodiscard]] int gate_count() const;
[[nodiscard]] int circuit_depth() const;
Expand Down Expand Up @@ -273,7 +274,8 @@ class Graph {
bool continue_storing_all_steps = false);
void constant_and_rotation_elimination();
void rotation_merging(GateType target_rotation);
std::string to_qasm(bool print_result = false, bool print_guid = false) const;
[[nodiscard]] std::string to_qasm(bool print_result = false,
bool print_guid = false) const;
void to_qasm(const std::string &save_filename, bool print_result,
bool print_guid) const;
template <class _CharT, class _Traits>
Expand All @@ -283,10 +285,10 @@ class Graph {
static std::shared_ptr<Graph> from_qasm_file(Context *ctx,
const std::string &filename);
static std::shared_ptr<Graph> from_qasm_str(Context *ctx,
const std::string qasm_str);
const std::string &qasm_str);
void draw_circuit(const std::string &qasm_str,
const std::string &save_filename);
size_t get_num_qubits() const;
[[nodiscard]] size_t get_num_qubits() const;
void print_qubit_ops();
std::shared_ptr<Graph> toffoli_flip_greedy(GateType target_rotation,
GraphXfer *xfer,
Expand All @@ -298,9 +300,9 @@ class Graph {
toffoli_flip_by_instruction(GateType target_rotation, GraphXfer *xfer,
GraphXfer *inverse_xfer,
std::vector<int> instruction);
std::vector<size_t> appliable_xfers(Op op,
const std::vector<GraphXfer *> &) const;
std::vector<size_t>
[[nodiscard]] std::vector<size_t>
appliable_xfers(Op op, const std::vector<GraphXfer *> &) const;
[[nodiscard]] std::vector<size_t>
appliable_xfers_parallel(Op op, const std::vector<GraphXfer *> &) const;
bool xfer_appliable(GraphXfer *xfer, Op op) const;
std::shared_ptr<Graph> apply_xfer(GraphXfer *xfer, Op op,
Expand All @@ -318,16 +320,16 @@ class Graph {
std::shared_ptr<Graph> ccz_flip_greedy_rz();
std::shared_ptr<Graph> ccz_flip_greedy_u1();
bool _loop_check_after_matching(GraphXfer *xfer) const;
std::shared_ptr<Graph>
[[nodiscard]] std::shared_ptr<Graph>
subgraph(const std::unordered_set<Op, OpHash> &ops) const;
std::vector<std::shared_ptr<Graph>>
[[nodiscard]] std::vector<std::shared_ptr<Graph>>
topology_partition(const int partition_gate_count) const;
/**
* Return the parameter value if the Op is a constant parameter,
* or return 0 otherwise.
*/
ParamType get_param_value(const Op &op) const;
bool param_has_value(const Op &op) const;
[[nodiscard]] ParamType get_param_value(const Op &op) const;
[[nodiscard]] bool param_has_value(const Op &op) const;

private:
void replace_node(Op oldOp, Op newOp);
Expand All @@ -344,11 +346,23 @@ class Graph {
bool moveable(GateType tp);
bool move_forward(Pos &pos, bool left);
bool merge_2_rotation_op(Op op_0, Op op_1);
// The common core part of the API xfer_appliable, apply_xfer, and
// apply_xfer_and_track_node. Matches the src dag of xfer to the local dag
// in the circuit whose topological-order root is op. If failed, it
// automatically unmaps the matched nodes. Otherwise, the caller should
// unmap the matched nodes after their work is done.
/**
* The common core part of the API xfer_applicable, apply_xfer, and
* apply_xfer_and_track_node. Matches the src dag of xfer to the local dag
* in the circuit whose topological-order root is op. If failed, it
* automatically unmaps the matched nodes. Otherwise, the caller should
* unmap the matched nodes after their work is done.
* Because the src dag is connected, this function traverses the src dag
* and try to match the local dag in a connected way (not necessarily in
* topological order). For example, the pattern H(q0) H(q1) CX(q0, q1)
* may be matched in the order
* H(q0) (match qubit 0) -> CX(q0, q1) (match qubit 1) <- H(q1).
* @param xfer The xfer with a src dag and a target dag.
* @param op The node to begin with in the local dag.
* @param matched_opx_op_pairs_dq A deque to store matched nodes from
* the src dag to the local dag.
* @return If the match is successful.
*/
bool _pattern_matching(
GraphXfer *xfer, Op op,
std::deque<std::pair<OpX *, Op>> &matched_opx_op_pairs_dq) const;
Expand Down
Loading