Skip to content

[Context] Halved parameter flag. #191

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion src/quartz/context/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Context::Context(const std::vector<GateType> &supported_gates,
ParamInfo *param_info)
: global_unique_id(16384), supported_gates_(supported_gates),
param_info_(param_info) {
// Precomputes the supported gates.
gates_.reserve(supported_gates.size());
for (const auto &gate : supported_gates) {
insert_gate(gate);
Expand All @@ -24,6 +25,16 @@ Context::Context(const std::vector<GateType> &supported_gates,
supported_quantum_gates_.emplace_back(gate);
}
}

// Precomputes whether any of the gates use halved parameters.
may_use_halved_params_ = false;
for (const auto &g : gates_) {
for (int i = 0; i < g.second->get_num_parameters(); ++i) {
if (g.second->is_param_halved(i)) {
may_use_halved_params_ = true;
}
}
}
}

Context::Context(const std::vector<GateType> &supported_gates, int num_qubits,
Expand Down Expand Up @@ -206,7 +217,9 @@ int Context::get_new_param_id(const ParamType &param) {
return param_info_->get_new_param_id(param);
}

int Context::get_new_param_id() { return param_info_->get_new_param_id(); }
int Context::get_new_param_id() {
return param_info_->get_new_param_id(may_use_halved_params_);
}

int Context::get_new_param_expression_id(
const std::vector<int> &parameter_indices, Gate *op) {
Expand All @@ -233,6 +246,10 @@ bool Context::param_is_expression(int id) const {
return param_info_->param_is_expression(id);
}

bool Context::param_is_halved(int id) const {
return param_info_->param_is_halved(id);
}

CircuitWire *Context::get_param_wire(int id) const {
return param_info_->get_param_wire(id);
}
Expand Down
2 changes: 2 additions & 0 deletions src/quartz/context/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class Context {
[[nodiscard]] bool param_is_symbolic(int id) const;
[[nodiscard]] bool param_has_value(int id) const;
[[nodiscard]] bool param_is_expression(int id) const;
[[nodiscard]] bool param_is_halved(int id) const;

[[nodiscard]] CircuitWire *get_param_wire(int id) const;

Expand Down Expand Up @@ -163,6 +164,7 @@ class Context {
bool insert_gate(GateType tp);

size_t global_unique_id;
bool may_use_halved_params_;
std::unordered_map<GateType, std::unique_ptr<Gate>> gates_;
std::unordered_map<
GateType, std::unordered_map<std::vector<bool>, std::unique_ptr<Gate>>>
Expand Down
15 changes: 12 additions & 3 deletions src/quartz/context/param_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ bool is_symbolic_constant(Gate *op) {
return rv;
}

ParamInfo::ParamInfo(int num_input_symbolic_params) {
ParamInfo::ParamInfo(int num_input_symbolic_params, bool is_halved) {
gen_random_parameters(num_input_symbolic_params);
for (int i = 0; i < num_input_symbolic_params; i++) {
get_new_param_id();
get_new_param_id(is_halved);
}
}

Expand Down Expand Up @@ -56,7 +56,9 @@ std::vector<ParamType> ParamInfo::get_all_input_param_values() const {

int ParamInfo::get_new_param_id(const ParamType &param) {
int id = (int)is_parameter_symbolic_.size();
assert(id == (int)is_parameter_halved_.size());
is_parameter_symbolic_.push_back(false);
is_parameter_halved_.push_back(false);
auto wire = std::make_unique<CircuitWire>();
wire->type = CircuitWire::input_param;
wire->index = id;
Expand All @@ -65,9 +67,10 @@ int ParamInfo::get_new_param_id(const ParamType &param) {
return id;
}

int ParamInfo::get_new_param_id() {
int ParamInfo::get_new_param_id(bool is_halved) {
int id = (int)is_parameter_symbolic_.size();
is_parameter_symbolic_.push_back(true);
is_parameter_halved_.push_back(is_halved);
// Make sure to generate a random parameter for each symbolic parameter.
gen_random_parameters(id + 1);
auto wire = std::make_unique<CircuitWire>();
Expand Down Expand Up @@ -98,6 +101,7 @@ int ParamInfo::get_new_param_expression_id(
}
int id = (int)is_parameter_symbolic_.size();
is_parameter_symbolic_.push_back(true);
is_parameter_halved_.push_back(false);
auto circuit_gate = std::make_unique<CircuitGate>();
circuit_gate->gate = op;
for (auto &input_id : parameter_indices) {
Expand Down Expand Up @@ -137,6 +141,11 @@ bool ParamInfo::param_is_expression(int id) const {
!parameter_wires_[id]->input_gates.empty();
}

bool ParamInfo::param_is_halved(int id) const {
return id >= 0 && id < (int)is_parameter_halved_.size() &&
is_parameter_halved_[id];
}

CircuitWire *ParamInfo::get_param_wire(int id) const {
if (id >= 0 && id < (int)parameter_wires_.size()) {
return parameter_wires_[id].get();
Expand Down
9 changes: 6 additions & 3 deletions src/quartz/context/param_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ class ParamInfo {
/**
* Default constructor: initialize 0 parameters.
*/
ParamInfo() : ParamInfo(0) {}
ParamInfo() : ParamInfo(0, false) {}
/**
* Initialize |num_input_symbolic_params| input symbolic parameters.
*/
explicit ParamInfo(int num_input_symbolic_params);
explicit ParamInfo(int num_input_symbolic_params, bool is_halved);

/**
* Generate random values for random testing for input symbolic parameters.
Expand Down Expand Up @@ -58,9 +58,10 @@ class ParamInfo {
int get_new_param_id(const ParamType &param);
/**
* Create a new symbolic parameter.
* @param is_halved If true, then used by a gate with period 4*pi.
* @return The index of the new symbolic parameter.
*/
int get_new_param_id();
int get_new_param_id(bool is_halved);
/**
* Create a new parameter expression. If all input parameters are concrete,
* compute the result directly instead of creating the expression.
Expand All @@ -75,6 +76,7 @@ class ParamInfo {
[[nodiscard]] bool param_is_symbolic(int id) const;
[[nodiscard]] bool param_has_value(int id) const;
[[nodiscard]] bool param_is_expression(int id) const;
[[nodiscard]] bool param_is_halved(int id) const;

[[nodiscard]] CircuitWire *get_param_wire(int id) const;

Expand All @@ -101,6 +103,7 @@ class ParamInfo {
std::vector<ParamType> parameter_values_;
std::vector<std::unique_ptr<CircuitWire>> parameter_wires_;
std::vector<bool> is_parameter_symbolic_;
std::vector<bool> is_parameter_halved_;
// A holder for parameter expressions.
std::vector<std::unique_ptr<CircuitGate>> parameter_expressions_;

Expand Down
2 changes: 2 additions & 0 deletions src/quartz/gate/gate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ bool Gate::is_sparse() const { return false; }

bool Gate::is_diagonal() const { return false; }

bool Gate::is_param_halved(int i) const { return false; }

int Gate::get_num_control_qubits() const { return 0; }

std::vector<bool> Gate::get_control_state() const {
Expand Down
7 changes: 7 additions & 0 deletions src/quartz/gate/gate.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ class Gate {
* Default value is false.
*/
[[nodiscard]] virtual bool is_diagonal() const;
/**
* @param i the index of the parameter to check
* @return True if this gate is parameterized and parameter i has a period of
* 4*pi as opposed to 2*pi (e.g., rx, ry, rz).
* Default value is false.
*/
[[nodiscard]] virtual bool is_param_halved(int i) const;
/**
* @return The number of control qubits for controlled gates; or 0 if it is
* not a controlled gate.
Expand Down
1 change: 1 addition & 0 deletions src/quartz/gate/rx.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class RXGate : public Gate {
}
return cached_matrices[theta].get();
}
bool is_param_halved(int i) const override { return true; }
std::unordered_map<float, std::unique_ptr<Matrix<2>>> cached_matrices;
};

Expand Down
1 change: 1 addition & 0 deletions src/quartz/gate/ry.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class RYGate : public Gate {
}
return cached_matrices[theta].get();
}
bool is_param_halved(int i) const override { return true; }
std::unordered_map<float, std::unique_ptr<Matrix<2>>> cached_matrices;
};

Expand Down
1 change: 1 addition & 0 deletions src/quartz/gate/rz.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class RZGate : public Gate {
return cached_matrices[theta].get();
}
bool is_sparse() const override { return true; }
bool is_param_halved(int i) const override { return true; }
std::unordered_map<float, std::unique_ptr<Matrix<2>>> cached_matrices;
};

Expand Down
1 change: 1 addition & 0 deletions src/quartz/gate/u3.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class U3Gate : public Gate {
}
return cached_matrices[theta][phi][lambda].get();
}
bool is_param_halved(int i) const override { return i == 0; }
std::unordered_map<
float, std::unordered_map<
float, std::unordered_map<float, std::unique_ptr<Matrix<2>>>>>
Expand Down
2 changes: 1 addition & 1 deletion src/test/gen_ecc_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ void gen_ecc_set(const std::vector<GateType> &supported_gates,
const std::string &file_prefix, bool unique_parameters,
bool generate_representative_set, int num_qubits,
int num_input_parameters, int max_num_quantum_gates) {
ParamInfo param_info(/*num_input_symbolic_params=*/num_input_parameters);
ParamInfo param_info(num_input_parameters, false);
Context ctx(supported_gates, num_qubits, &param_info);
Generator gen(&ctx);

Expand Down
2 changes: 1 addition & 1 deletion src/test/test_all.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using namespace quartz;

int main() {
std::cout << "Hello, World!" << std::endl;
ParamInfo param_info(/*num_input_symbolic_params=*/2);
ParamInfo param_info(/*num_input_symbolic_params=*/2, false);
Context ctx({GateType::x, GateType::y, GateType::add, GateType::neg,
GateType::u2, GateType::u3, GateType::cx},
2, &param_info);
Expand Down
2 changes: 1 addition & 1 deletion src/test/test_bfs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ int main() {
const bool run_bfs_unverified = false;
const bool run_bfs_verified = true; // with representative pruning

ParamInfo param_info(/*num_input_symbolic_params=*/num_input_parameters);
ParamInfo param_info(num_input_parameters, false);
Context ctx({GateType::h}, num_qubits, &param_info);
Generator gen(&ctx);

Expand Down
2 changes: 1 addition & 1 deletion src/test/test_constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
using namespace quartz;

int main() {
ParamInfo param_info(0);
ParamInfo param_info;
Context ctx({GateType::rx}, 2, &param_info);

QASMParser parser(&ctx);
Expand Down
2 changes: 1 addition & 1 deletion src/test/test_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
using namespace quartz;

int main() {
ParamInfo param_info(/*num_input_symbolic_params=*/3);
ParamInfo param_info(/*num_input_symbolic_params=*/3, false);
Context ctx({GateType::x, GateType::y, GateType::cx, GateType::h}, 3,
&param_info);
Generator gen(&ctx);
Expand Down
2 changes: 1 addition & 1 deletion src/test/test_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ void test_generator(const std::vector<GateType> &support_gates, int num_qubits,
int max_num_input_parameters, int max_num_gates,
bool verbose, const std::string &save_file_name,
bool count_minimal_representations = false) {
ParamInfo param_info(/*num_input_symbolic_params=*/max_num_input_parameters);
ParamInfo param_info(max_num_input_parameters, false);
Context ctx(support_gates, num_qubits, &param_info);
Generator generator(&ctx);
Dataset dataset;
Expand Down
2 changes: 1 addition & 1 deletion src/test/test_mult.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
using namespace quartz;

int main() {
ParamInfo param_info(0);
ParamInfo param_info;
Context ctx({GateType::rx, GateType::mult}, 1, &param_info);

auto p0 = ctx.get_new_param_id(2.0);
Expand Down
2 changes: 1 addition & 1 deletion src/test/test_optimize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
using namespace quartz;

int main() {
ParamInfo param_info(/*num_input_symbolic_params=*/2);
ParamInfo param_info(/*num_input_symbolic_params=*/2, false);
Context ctx({GateType::input_qubit, GateType::input_param, GateType::cx,
GateType::h, GateType::rz, GateType::x, GateType::add},
/*num_qubits=*/3, &param_info);
Expand Down
2 changes: 1 addition & 1 deletion src/test/test_phase_shift.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ int main() {
const int num_input_parameters = 1;
const int max_num_gates = 2;
const int max_num_param_gates = 1;
ParamInfo param_info(/*num_input_symbolic_params=*/num_input_parameters);
ParamInfo param_info(num_input_parameters, false);
Context ctx({GateType::rz, GateType::u1, GateType::add}, num_qubits,
&param_info);

Expand Down
2 changes: 1 addition & 1 deletion src/test/test_pi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
using namespace quartz;

int main() {
ParamInfo param_info(0);
ParamInfo param_info;
Context ctx({GateType::rx, GateType::pi}, 1, &param_info);

auto p0 = ctx.get_new_param_id(PI / 2);
Expand Down
2 changes: 1 addition & 1 deletion src/test/test_pruning.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ void test_pruning(
int max_num_param_gates = 1, bool run_representative_pruning = true,
bool run_original = true, bool run_original_unverified = false,
bool run_original_verified = true, bool unique_parameters = false) {
ParamInfo param_info(/*num_input_symbolic_params=*/num_input_parameters);
ParamInfo param_info(num_input_parameters, false);
Context ctx(supported_gates, num_qubits, &param_info);
Generator gen(&ctx);

Expand Down
Loading
Loading