Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Optimizer] Support context shift with multiple rules with the same source gate type #168

Merged
merged 2 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions circuit/example-circuits/rz_multiples.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
OPENQASM 2.0;
include "qelib1.inc";
qreg q[1];
rz(pi*0.250000) q[0];
rz(pi*-0.250000) q[0];
rz(pi*-0.500000) q[0];
rz(pi*0.500000) q[0];
rz(pi*0.750000) q[0];
rz(pi*-0.750000) q[0];
rz(pi*1.000000) q[0];
rz(pi*0.000000) q[0];
121 changes: 58 additions & 63 deletions src/quartz/context/rule_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ class Command {
Command(const std::string &str_command) {
std::istringstream iss(str_command);
std::string gate_tp;
while (gate_tp == "") {
while (gate_tp.empty()) {
getline(iss, gate_tp, ' ');
}
tp = to_gate_type(gate_tp);

std::string input;
while (!iss.eof()) {
getline(iss, input, ' ');
if (input == "")
if (input.empty())
continue;
if (input[0] == 'q') {
qubit_idx.push_back(stoi(input.substr(1, input.size() - 1)));
Expand Down Expand Up @@ -79,12 +79,12 @@ class Command {

class RuleParser {
public:
RuleParser(std::vector<std::string> rules_) {
for (auto rule : rules_) {
explicit RuleParser(const std::vector<std::string> &rules_) {
for (const auto &rule : rules_) {
std::string gate_name;
auto pos = rule.find("=");
auto pos = rule.find('=');
assert(pos != rule.npos);
auto src_cmd = Command(rule.substr(0, rule.find("=")));
auto src_cmd = Command(rule.substr(0, rule.find('=')));
GateType tp = src_cmd.get_gate_type();

std::istringstream iss1(rule.substr(pos + 1));
Expand All @@ -94,7 +94,7 @@ class RuleParser {
getline(iss1, input, ';');
// std::cout << input << std::endl;
if (!input.empty()) {
cmds.push_back(Command(input));
cmds.emplace_back(input);
}
}
std::set<GateType> tp_set;
Expand All @@ -106,93 +106,88 @@ class RuleParser {
std::vector<std::pair<
Command, std::pair<std::vector<Command>, std::set<GateType>>>>
tp_rules;
tp_rules.push_back(
std::make_pair(src_cmd, std::make_pair(cmds, tp_set)));
tp_rules.emplace_back(src_cmd, std::make_pair(cmds, tp_set));
rules[tp] = tp_rules;
} else {
rules[tp].push_back(
std::make_pair(src_cmd, std::make_pair(cmds, tp_set)));
rules[tp].emplace_back(src_cmd, std::make_pair(cmds, tp_set));
}
}
}

bool find_convert_commands(Context *ctx, const GateType tp, Command &src_cmd,
std::vector<Command> &cmds) {
/**
* Find all conversion commands for a gate type.
* @param ctx The destination context.
* @param tp The gate type.
* @param src_cmd Return the source commands.
* @param cmds Return the target commands.
* @return The number of commands returned.
*/
int find_convert_commands(Context *ctx, const GateType tp,
std::vector<Command> &src_cmd,
std::vector<std::vector<Command>> &cmds) {
src_cmd.clear();
cmds.clear();
if (rules.find(tp) == rules.end()) {
std::cout << "No rules found to fit gate to context" << std::endl;
return false;
std::cerr
<< "No rules with the same gate type found to fit gate to context."
<< std::endl;
return 0;
}

std::set<GateType> supported_gate_tp_set(ctx->get_supported_gates().begin(),
ctx->get_supported_gates().end());
std::vector<
std::pair<Command, std::pair<std::vector<Command>, std::set<GateType>>>>
cmds_list = rules[tp];
for (auto cmds_info : cmds_list) {
for (const auto &cmds_info : cmds_list) {
std::set<GateType> used_gate_tp_set = cmds_info.second.second;
bool not_found = false;
for (auto it = used_gate_tp_set.begin(); it != used_gate_tp_set.end();
++it) {
if (supported_gate_tp_set.find(*it) == supported_gate_tp_set.end()) {
for (auto it : used_gate_tp_set) {
if (supported_gate_tp_set.find(it) == supported_gate_tp_set.end()) {
not_found = true;
break;
}
}
if (!not_found) {
cmds = cmds_info.second.first;
src_cmd = cmds_info.first;
// for (auto cmd : cmds) {
// cmd.print();
// }
return true;
src_cmd.push_back(cmds_info.first);
cmds.push_back(cmds_info.second.first);
}
}
std::cout << "No rules found to fit gate to context" << std::endl;
return false;
return (int)src_cmd.size();
}

public:
static std::pair<RuleParser *, RuleParser *> ccz_cx_rz_rules() {
RuleParser *rule_0 =
new RuleParser({"ccz q0 q1 q2 = cx q1 q2; rz q2 -0.25pi; cx q0 q2; rz "
"q2 0.25pi; cx q1 q2; rz q2 -0.25pi; cx "
"q0 q2; cx q0 q1; rz q1 -0.25pi; cx q0 q1; rz q0 "
"0.25pi; rz q1 0.25pi; rz q2 0.25pi;"});
RuleParser *rule_1 =
new RuleParser({"ccz q0 q1 q2 = cx q1 q2; rz q2 0.25pi; cx q0 q2; rz "
"q2 -0.25pi; cx q1 q2; rz q2 0.25pi; cx "
"q0 q2; cx q0 q1; rz q1 0.25pi; cx q0 q1; rz q0 "
"-0.25pi; rz q1 -0.25pi; rz q2 -0.25pi;"});
return std::make_pair(rule_0, rule_1);
static RuleParser ccz_cx_rz_rules() {
return RuleParser({"ccz q0 q1 q2 = cx q1 q2; rz q2 -0.25pi; cx q0 q2; rz "
"q2 0.25pi; cx q1 q2; rz q2 -0.25pi; cx "
"q0 q2; cx q0 q1; rz q1 -0.25pi; cx q0 q1; rz q0 "
"0.25pi; rz q1 0.25pi; rz q2 0.25pi;",
"ccz q0 q1 q2 = cx q1 q2; rz q2 0.25pi; cx q0 q2; rz "
"q2 -0.25pi; cx q1 q2; rz q2 0.25pi; cx "
"q0 q2; cx q0 q1; rz q1 0.25pi; cx q0 q1; rz q0 "
"-0.25pi; rz q1 -0.25pi; rz q2 -0.25pi;"});
}

static std::pair<RuleParser *, RuleParser *> ccz_cx_u1_rules() {
RuleParser *rule_0 =
new RuleParser({"ccz q0 q1 q2 = cx q1 q2; u1 q2 -0.25pi; cx q0 q2; u1 "
"q2 0.25pi; cx q1 q2; u1 q2 -0.25pi; cx "
"q0 q2; cx q0 q1; u1 q1 -0.25pi; cx q0 q1; u1 q0 "
"0.25pi; u1 q1 0.25pi; u1 q2 0.25pi;"});
RuleParser *rule_1 =
new RuleParser({"ccz q0 q1 q2 = cx q1 q2; u1 q2 0.25pi; cx q0 q2; u1 "
"q2 -0.25pi; cx q1 q2; u1 q2 0.25pi; cx "
"q0 q2; cx q0 q1; u1 q1 0.25pi; cx q0 q1; u1 q0 "
"-0.25pi; u1 q1 -0.25pi; u1 q2 -0.25pi;"});
return std::make_pair(rule_0, rule_1);
static RuleParser ccz_cx_u1_rules() {
return RuleParser({"ccz q0 q1 q2 = cx q1 q2; u1 q2 -0.25pi; cx q0 q2; u1 "
"q2 0.25pi; cx q1 q2; u1 q2 -0.25pi; cx "
"q0 q2; cx q0 q1; u1 q1 -0.25pi; cx q0 q1; u1 q0 "
"0.25pi; u1 q1 0.25pi; u1 q2 0.25pi;",
"ccz q0 q1 q2 = cx q1 q2; u1 q2 0.25pi; cx q0 q2; u1 "
"q2 -0.25pi; cx q1 q2; u1 q2 0.25pi; cx "
"q0 q2; cx q0 q1; u1 q1 0.25pi; cx q0 q1; u1 q0 "
"-0.25pi; u1 q1 -0.25pi; u1 q2 -0.25pi;"});
}

static std::pair<RuleParser *, RuleParser *> ccz_cx_t_rules() {
RuleParser *rule_0 =
new RuleParser({"ccz q0 q1 q2 = cx q1 q2; tdg q2; cx q0 q2; t "
"q2; cx q1 q2; tdg q2; cx "
"q0 q2; cx q0 q1; tdg q1; cx q0 q1; t q0"
"; t q1; t q2;"});
RuleParser *rule_1 =
new RuleParser({"ccz q0 q1 q2 = cx q1 q2; t q2; cx q0 q2; tdg "
"q2; cx q1 q2; t q2; cx "
"q0 q2; cx q0 q1; t q1; cx q0 q1; tdg q0"
"; tdg q1; tdg q2;"});
return std::make_pair(rule_0, rule_1);
static RuleParser ccz_cx_t_rules() {
return RuleParser({"ccz q0 q1 q2 = cx q1 q2; tdg q2; cx q0 q2; t "
"q2; cx q1 q2; tdg q2; cx "
"q0 q2; cx q0 q1; tdg q1; cx q0 q1; t q0"
"; t q1; t q2;",
"ccz q0 q1 q2 = cx q1 q2; t q2; cx q0 q2; tdg "
"q2; cx q1 q2; t q2; cx "
"q0 q2; cx q0 q1; t q1; cx q0 q1; tdg q0"
"; tdg q1; tdg q2;"});
}

private:
Expand Down
84 changes: 35 additions & 49 deletions src/quartz/tasograph/substitution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,40 +382,38 @@ GraphXfer *GraphXfer::create_GraphXfer_from_qasm_str(
delete src_dag;
delete dst_dag;
return graphXfer;

return graphXfer;
}

GraphXfer *
GraphXfer::create_single_gate_GraphXfer(Context *union_ctx, Command src_cmd,
std::vector<Command> dst_cmds) {
// Currently only support source command with no constant parameters
// Assume the only added parameters are constant parameters
// Assume the number of non constant parameters are equal
const std::vector<Command> &dst_cmds) {
GateType src_tp = src_cmd.get_gate_type();
GraphXfer *graphXfer = new GraphXfer(union_ctx);

Gate *gate = union_ctx->get_gate(src_tp);
auto num_qubit = gate->get_num_qubits();
auto num_non_constant_params = gate->get_num_parameters();
auto num_qubits = gate->get_num_qubits();

OpX *src_op = new OpX(src_tp);
std::map<int, TensorX> dst_qubits_2_tensorx;
std::map<int, TensorX> dst_params_2_tensorx;

for (int i = 0; i < num_qubit; ++i) {
for (int i = 0; i < num_qubits; ++i) {
TensorX qubit_tensor = graphXfer->new_tensor();
src_op->add_input(qubit_tensor);
dst_qubits_2_tensorx[i] = qubit_tensor;
}

for (int i = 0; i < num_non_constant_params; ++i) {
for (int i = 0; i < gate->get_num_parameters(); ++i) {
TensorX param_tensor = graphXfer->new_tensor();
src_op->add_input(param_tensor);
dst_params_2_tensorx[i] = param_tensor;
if (src_cmd.param_idx[i] != -1) {
dst_params_2_tensorx[src_cmd.param_idx[i]] = param_tensor;
} else {
graphXfer->paramValues[param_tensor.idx] = src_cmd.constant_params[i];
}
}

for (int i = 0; i < num_qubit; ++i) {
for (int i = 0; i < num_qubits; ++i) {
TensorX tensor(src_op, i);
src_op->add_output(tensor);
}
Expand Down Expand Up @@ -450,7 +448,7 @@ GraphXfer::create_single_gate_GraphXfer(Context *union_ctx, Command src_cmd,
}
graphXfer->dstOps.push_back(op);
}
for (int i = 0; i < num_qubit; ++i) {
for (int i = 0; i < num_qubits; ++i) {
graphXfer->map_output(src_op->outputs[i],
dst_qubits_2_tensorx[src_cmd.qubit_idx[i]]);
}
Expand All @@ -460,54 +458,42 @@ GraphXfer::create_single_gate_GraphXfer(Context *union_ctx, Command src_cmd,
std::pair<GraphXfer *, GraphXfer *> GraphXfer::ccz_cx_rz_xfer(Context *ctx) {
Context dst_ctx({GateType::rz, GateType::cx, GateType::input_qubit,
GateType::input_param});
std::pair<RuleParser *, RuleParser *> toffoli_rules =
RuleParser::ccz_cx_rz_rules();
std::vector<Command> cmds;
Command cmd;
toffoli_rules.first->find_convert_commands(&dst_ctx, GateType::ccz, cmd,
cmds);
GraphXfer *xfer_0 = create_single_gate_GraphXfer(ctx, cmd, cmds);
toffoli_rules.second->find_convert_commands(&dst_ctx, GateType::ccz, cmd,
cmds);
GraphXfer *xfer_1 = create_single_gate_GraphXfer(ctx, cmd, cmds);
delete toffoli_rules.first;
delete toffoli_rules.second;
auto toffoli_rules = RuleParser::ccz_cx_rz_rules();
std::vector<std::vector<Command>> cmds;
std::vector<Command> cmd;
auto num_xfers =
toffoli_rules.find_convert_commands(&dst_ctx, GateType::ccz, cmd, cmds);
assert(num_xfers == 2);
GraphXfer *xfer_0 = create_single_gate_GraphXfer(ctx, cmd[0], cmds[0]);
GraphXfer *xfer_1 = create_single_gate_GraphXfer(ctx, cmd[1], cmds[1]);
return std::make_pair(xfer_0, xfer_1);
}

std::pair<GraphXfer *, GraphXfer *> GraphXfer::ccz_cx_u1_xfer(Context *ctx) {
Context dst_ctx({GateType::u1, GateType::cx, GateType::input_qubit,
GateType::input_param});
std::pair<RuleParser *, RuleParser *> toffoli_rules =
RuleParser::ccz_cx_u1_rules();
std::vector<Command> cmds;
Command cmd;
toffoli_rules.first->find_convert_commands(&dst_ctx, GateType::ccz, cmd,
cmds);
GraphXfer *xfer_0 = create_single_gate_GraphXfer(ctx, cmd, cmds);
toffoli_rules.second->find_convert_commands(&dst_ctx, GateType::ccz, cmd,
cmds);
GraphXfer *xfer_1 = create_single_gate_GraphXfer(ctx, cmd, cmds);
delete toffoli_rules.first;
delete toffoli_rules.second;
auto toffoli_rules = RuleParser::ccz_cx_u1_rules();
std::vector<std::vector<Command>> cmds;
std::vector<Command> cmd;
auto num_xfers =
toffoli_rules.find_convert_commands(&dst_ctx, GateType::ccz, cmd, cmds);
assert(num_xfers == 2);
GraphXfer *xfer_0 = create_single_gate_GraphXfer(ctx, cmd[0], cmds[0]);
GraphXfer *xfer_1 = create_single_gate_GraphXfer(ctx, cmd[1], cmds[1]);
return std::make_pair(xfer_0, xfer_1);
}

std::pair<GraphXfer *, GraphXfer *> GraphXfer::ccz_cx_t_xfer(Context *ctx) {
Context dst_ctx({GateType::t, GateType::tdg, GateType::cx,
GateType::input_qubit, GateType::input_param});
std::pair<RuleParser *, RuleParser *> toffoli_rules =
RuleParser::ccz_cx_t_rules();
std::vector<Command> cmds;
Command cmd;
toffoli_rules.first->find_convert_commands(&dst_ctx, GateType::ccz, cmd,
cmds);
GraphXfer *xfer_0 = create_single_gate_GraphXfer(ctx, cmd, cmds);
toffoli_rules.second->find_convert_commands(&dst_ctx, GateType::ccz, cmd,
cmds);
GraphXfer *xfer_1 = create_single_gate_GraphXfer(ctx, cmd, cmds);
delete toffoli_rules.first;
delete toffoli_rules.second;
auto toffoli_rules = RuleParser::ccz_cx_t_rules();
std::vector<std::vector<Command>> cmds;
std::vector<Command> cmd;
auto num_xfers =
toffoli_rules.find_convert_commands(&dst_ctx, GateType::ccz, cmd, cmds);
assert(num_xfers == 2);
GraphXfer *xfer_0 = create_single_gate_GraphXfer(ctx, cmd[0], cmds[0]);
GraphXfer *xfer_1 = create_single_gate_GraphXfer(ctx, cmd[1], cmds[1]);
return std::make_pair(xfer_0, xfer_1);
}

Expand Down
6 changes: 3 additions & 3 deletions src/quartz/tasograph/substitution.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ class GraphXfer {
static GraphXfer *create_GraphXfer_from_qasm_str(Context *_context,
const std::string &src_str,
const std::string &dst_str);
static GraphXfer *create_single_gate_GraphXfer(Context *union_ctx,
Command src_cmd,
std::vector<Command> dst_cmds);
static GraphXfer *
create_single_gate_GraphXfer(Context *union_ctx, Command src_cmd,
const std::vector<Command> &dst_cmds);
static std::pair<GraphXfer *, GraphXfer *> ccz_cx_rz_xfer(Context *ctx);
static std::pair<GraphXfer *, GraphXfer *> ccz_cx_u1_xfer(Context *ctx);
static std::pair<GraphXfer *, GraphXfer *> ccz_cx_t_xfer(Context *ctx);
Expand Down
23 changes: 12 additions & 11 deletions src/quartz/tasograph/tasograph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -491,25 +491,26 @@ std::shared_ptr<Graph> Graph::context_shift(Context *src_ctx, Context *dst_ctx,
auto src_gates = src_ctx->get_supported_gates();
auto dst_gate_set = std::set<GateType>(dst_ctx->get_supported_gates().begin(),
dst_ctx->get_supported_gates().end());
std::map<GateType, GraphXfer *> tp_2_xfer;
std::vector<GraphXfer *> xfers;
for (auto gate_tp : src_gates) {
if (ignore_toffoli && src_ctx->get_gate(gate_tp)->is_toffoli_gate())
continue;
if (dst_gate_set.find(gate_tp) == dst_gate_set.end()) {
std::vector<Command> cmds;
Command src_cmd;
assert(
rule_parser->find_convert_commands(dst_ctx, gate_tp, src_cmd, cmds));

tp_2_xfer[gate_tp] =
GraphXfer::create_single_gate_GraphXfer(union_ctx, src_cmd, cmds);
std::vector<std::vector<Command>> cmds;
std::vector<Command> src_cmd;
int num_xfers =
rule_parser->find_convert_commands(dst_ctx, gate_tp, src_cmd, cmds);
assert(num_xfers > 0);
for (int i = 0; i < num_xfers; i++) {
xfers.push_back(GraphXfer::create_single_gate_GraphXfer(
union_ctx, src_cmd[i], cmds[i]));
}
}
}
std::shared_ptr<Graph> src_graph(new Graph(*this));
std::shared_ptr<Graph> dst_graph(nullptr);
for (auto it = tp_2_xfer.begin(); it != tp_2_xfer.end(); ++it) {
while ((dst_graph = it->second->run_1_time(0, src_graph.get())) !=
nullptr) {
for (auto &xfer : xfers) {
while ((dst_graph = xfer->run_1_time(0, src_graph.get())) != nullptr) {
src_graph = dst_graph;
}
}
Expand Down
Loading
Loading