Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parser support for halved parameters. #193

Merged
merged 5 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/quartz/circuitseq/circuitgate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ std::string CircuitGate::to_qasm_style_string(Context *ctx,
int param_precision) const {
assert(gate->is_quantum_gate());
std::string result;

// Prints entry to control block, if gate is controlled.
if (gate->get_num_control_qubits() > 0) {
auto control_state = gate->get_control_state();
if (!std::all_of(control_state.begin(), control_state.end(),
Expand All @@ -249,34 +251,48 @@ std::string CircuitGate::to_qasm_style_string(Context *ctx,
}
}

// Prints gate name.
auto gate_name = gate_type_name(gate->tp);
std::transform(gate_name.begin(), gate_name.end(), gate_name.begin(),
[](unsigned char c) { return std::tolower(c); });
result += gate_name;

// Prints parameters.
if (gate->get_num_parameters() > 0) {
int num_remaining_parameters = gate->get_num_parameters();
int curr_param_index = 0;
result += "(";
for (auto input_wire : input_wires) {
if (input_wire->is_parameter()) {
// Ensures the wire is valid.
assert(ctx->param_has_value(input_wire->index));

// Determines the parameter value with respect to reparameterization.
std::ostringstream out;
out.precision(param_precision);
const auto &param_value = ctx->get_param_value(input_wire->index);
if (param_value == 0) {
// optimization: if a parameter is 0, do not output that many digits
out << "0";
} else if (gate->is_param_halved(curr_param_index)) {
out << std::fixed << 2 * param_value;
} else {
out << std::fixed << param_value;
}
result += std::move(out).str();

// Prepares for printing the next parameter.
num_remaining_parameters--;
curr_param_index++;
if (num_remaining_parameters != 0) {
result += ",";
}
}
}
result += ")";
}

// Prints target qubits.
result += " ";
bool first_qubit = true;
for (auto input_wire : input_wires) {
Expand All @@ -291,6 +307,7 @@ std::string CircuitGate::to_qasm_style_string(Context *ctx,
}
result += ";\n";

// Prints exit from control block, if gate is controlled.
if (gate->get_num_control_qubits() > 0) {
auto control_state = gate->get_control_state();
if (!std::all_of(control_state.begin(), control_state.end(),
Expand Down
11 changes: 11 additions & 0 deletions src/quartz/circuitseq/circuitseq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -994,11 +994,22 @@ CircuitSeq::from_qasm_style_string(Context *ctx, const std::string &str) {

std::string CircuitSeq::to_qasm_style_string(Context *ctx,
int param_precision) const {
// Checks if parameters are in use.
for (auto param : get_input_param_indices(ctx)) {
if (ctx->param_is_symbolic(param)) {
std::cerr << "to_qasm_style_string only supports consts." << std::endl;
break;
}
}

// Generates header.
std::string result = "OPENQASM 2.0;\n"
"include \"qelib1.inc\";\n"
"qreg q[";
result += std::to_string(get_num_qubits());
result += "];\n";

// Populates gate list.
for (auto &gate : gates) {
result += gate->to_qasm_style_string(ctx, param_precision);
}
Expand Down
105 changes: 78 additions & 27 deletions src/quartz/parser/qasm_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,44 +66,56 @@ std::string strip(const std::string &input) {
return std::string(st, ed.base());
}

int ParamParser::parse_number(bool negative, ParamType p) {
int ParamParser::parse_number(bool negative, ParamType p, bool is_halved) {
// Handles negative constants.
if (negative) {
p = -p;
}

// Handles halved parameters.
if (is_halved) {
p = p / 2;
}

// Constructs the constant parameter if it does not already exist.
if (number_params_.count(p) == 0) {
int param_id = ctx_->get_new_param_id(p);
number_params_[p] = param_id;
}

return number_params_[p];
}

int ParamParser::parse_pi_term(bool negative, ParamType n, ParamType d) {
int ParamParser::parse_pi_term(bool negative, ParamType n, ParamType d,
bool is_halved) {
// If pi is not symbolic, then falls back to constants.
if (!symbolic_pi_) {
return parse_number(negative, n * PI / d);
return parse_number(negative, n * PI / d, is_halved);
}

// Handles negative coefficients.
if (negative) {
n = -n;
}

// Handles halved parameters.
if (is_halved) {
d = d * 2;
}

// Constructs the pi expression, if it does not already exist.
if (pi_params_[n].count(d) == 0) {
// Checks if fraction of pi already exists.
// If (n == 1) then this will cache the final expression.
if (pi_params_[1].count(d) == 0) {
int id = parse_number(false, d);
int id = parse_number(false, d, false);
auto gate = ctx_->get_gate(GateType::pi);
pi_params_[1][d] = ctx_->get_new_param_expression_id({id}, gate);
}

// Scales the fraction of pi when the numerator is not equal to 1.
// If (n != 1), then this will cache the final expression.
if (n != 1) {
int nid = parse_number(false, n);
int nid = parse_number(false, n, false);
int pid = pi_params_[1][d];
auto gate = ctx_->get_gate(GateType::mult);
pi_params_[n][d] = ctx_->get_new_param_expression_id({nid, pid}, gate);
Expand All @@ -114,6 +126,50 @@ int ParamParser::parse_pi_term(bool negative, ParamType n, ParamType d) {
return pi_params_[n][d];
}

int ParamParser::parse_symb_param(bool negative, std::string name, int i,
bool is_halved) {
// Attempts to look up the symbolic parameter identifier.
if (symb_params_[name].count(i) == 0) {
std::cerr << "Invalid parameter reference: " << name << "[" << i << "]"
<< std::endl;
assert(false);
return -1;
}
int param = symb_params_[name][i];

// Ensures that halved parameter requirements are obeyed.
if (is_halved) {
if (!ctx_->param_is_halved(param)) {
std::cerr << "Halved gate requires halved parameters." << std::endl;
assert(false);
return -1;
}
}

// Rescales halved parameters for gates which are not halved.
if (!is_halved) {
if (ctx_->param_is_halved(param)) {
if (sum_params_[param].count(param) == 0) {
auto add = ctx_->get_gate(GateType::add);
int dbl_id = ctx_->get_new_param_expression_id({param, param}, add);
sum_params_[param][param] = dbl_id;
}
param = sum_params_[param][param];
}
}

// Handles negative parameters.
if (negative) {
if (negative_symb_params.count(param) == 0) {
auto neg = ctx_->get_gate(GateType::neg);
int neg_id = ctx_->get_new_param_expression_id({param}, neg);
negative_symb_params[param] = neg_id;
}
param = negative_symb_params[param];
}
return param;
}

bool ParamParser::parse_array_decl(std::stringstream &ss) {
// The first two tokens of the stream should be '[angle len]'. Recall that the
// comma between angle and len has been replaced by a space.
Expand Down Expand Up @@ -168,7 +224,7 @@ bool ParamParser::parse_array_decl(std::stringstream &ss) {
return true;
}

int ParamParser::parse_expr(std::stringstream &ss) {
int ParamParser::parse_expr(std::stringstream &ss, bool is_halved) {
// Extracts the parameter expression from the string stream.
std::string token;
ss >> token;
Expand Down Expand Up @@ -199,7 +255,7 @@ int ParamParser::parse_expr(std::stringstream &ss) {
int tid;
if (pos == std::string::npos) {
// Case: t, -t
tid = parse_term(neg_prefix, token);
tid = parse_term(neg_prefix, token, is_halved);
token = "";
} else if (pos > 0) {
// Case: t+e, t-e
Expand All @@ -211,7 +267,7 @@ int ParamParser::parse_expr(std::stringstream &ss) {

// Parses the right-hand side as a token.
// The substraction is absorbed by this term as a negative sign.
tid = parse_term(is_minus, term);
tid = parse_term(is_minus, term, is_halved);
} else {
std::cerr << "Unexpected (+) or (-) at index 0: " << token << std::endl;
assert(false);
Expand Down Expand Up @@ -239,7 +295,7 @@ int ParamParser::parse_expr(std::stringstream &ss) {
return id;
}

int ParamParser::parse_term(bool negative, std::string token) {
int ParamParser::parse_term(bool negative, std::string token, bool is_halved) {
// Identifies the format case matching this token.
if (token.find("[") != std::string::npos) {
// Case: name[i]
Expand All @@ -255,31 +311,26 @@ int ParamParser::parse_term(bool negative, std::string token) {
// Determines the reference index.
int idx = string_to_number(istr);
if (idx == -1) {
std::cerr << "Invalid parameter reference index: " << istr << std::endl;
std::cerr << "Negative parameter reference index: " << istr << std::endl;
assert(false);
return false;
return -1;
}

// Attempts to look up the symbolic parameter identifier.
if (symb_params_[name].count(idx) == 0) {
std::cerr << "Invalid parameter reference: " << token << std::endl;
assert(false);
return false;
}
return symb_params_[name][idx];
// Resolves symbolic parameter.
return parse_symb_param(negative, name, idx, is_halved);
} else if (token.find("pi") == 0) {
if (token == "pi") {
// Case: pi
return parse_pi_term(negative, 1.0, 1.0);
return parse_pi_term(negative, 1.0, 1.0, is_halved);
} else {
// Cases: pi*0.123 or pi/2
auto d = token.substr(3, std::string::npos);
if (token[2] == '*') {
// Case: pi*0.123
return parse_pi_term(negative, std::stod(d), 1.0);
return parse_pi_term(negative, std::stod(d), 1.0, is_halved);
} else if (token[2] == '/') {
// Case: pi/2
return parse_pi_term(negative, 1.0, std::stod(d));
return parse_pi_term(negative, 1.0, std::stod(d), is_halved);
} else {
std::cerr << "Unsupported parameter format: " << token << std::endl;
assert(false);
Expand All @@ -295,7 +346,7 @@ int ParamParser::parse_term(bool negative, std::string token) {
ParamType p = std::stod(token.substr(0, token.find('/')));
p /= PI;
p /= std::stod(token.substr(lparen_pos + 1, mult_pos - lparen_pos - 1));
return parse_number(negative, p);
return parse_number(negative, p, is_halved);
} else {
// Case: 0.123*pi or 0.123*pi/2
auto d = token.substr(0, token.find('*'));
Expand All @@ -305,11 +356,11 @@ int ParamParser::parse_term(bool negative, std::string token) {
// Case: 0.123*pi/2
denom = std::stod(token.substr(token.find('/') + 1));
}
return parse_pi_term(negative, num, denom);
return parse_pi_term(negative, num, denom, is_halved);
}
} else {
// Case: 0.123
return parse_number(negative, std::stod(token));
return parse_number(negative, std::stod(token), is_halved);
}

// This line should be unreachable.
Expand Down Expand Up @@ -382,7 +433,7 @@ int QubitParser::parse_access(std::stringstream &ss) {
if (!finalized_) {
std::cerr << "Can only access qubits after finalization." << std::endl;
assert(false);
return false;
return -1;
}

// Gets qreg array name.
Expand Down Expand Up @@ -415,7 +466,7 @@ int QubitParser::finalize() {
if (finalized_) {
std::cerr << "Can only finalize qreg lookup once." << std::endl;
assert(false);
return false;
return -1;
}
finalized_ = true;

Expand Down
Loading
Loading