Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug] Manage parameters in ParamInfo and fix a bug causing Graph with parameter expressions to crash #166

Merged
merged 9 commits into from
Mar 2, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Clean up context.cpp
xumingkuan committed Mar 1, 2024
commit 7604e0f156a9a93f6f69766ed9b3e62885031949
56 changes: 9 additions & 47 deletions src/quartz/context/context.cpp
Original file line number Diff line number Diff line change
@@ -5,10 +5,10 @@
#include "quartz/utils/string_utils.h"

#include <cassert>
#include <cmath>
#include <iostream>
#include <mutex>
#include <random>
#include <set>

namespace quartz {
Context::Context(const std::vector<GateType> &supported_gates,
@@ -101,10 +101,6 @@ bool Context::insert_gate(GateType tp) {
return true;
}

void Context::gen_random_parameters(const int num_params) {
param_info_->gen_random_parameters(num_params);
}

const std::vector<GateType> &Context::get_supported_gates() const {
return supported_gates_;
}
@@ -203,16 +199,14 @@ void Context::set_param_value(int id, const ParamType &param) {
}

std::vector<ParamType> Context::get_all_input_param_values() const {
return param_info_->parameter_values_;
return param_info_->get_all_input_param_values();
}

int Context::get_new_param_id(const ParamType &param) {
return param_info_->get_new_param_id(param);
}

int Context::get_new_param_id() {
return param_info_->get_new_param_id();
}
int Context::get_new_param_id() { return param_info_->get_new_param_id(); }

int Context::get_new_param_expression_id(
const std::vector<int> &parameter_indices, Gate *op) {
@@ -228,42 +222,24 @@ int Context::get_num_input_symbolic_parameters() const {
}

bool Context::param_is_symbolic(int id) const {
return id >= 0 && id < (int)param_info_->is_parameter_symbolic_.size() &&
param_info_->is_parameter_symbolic_[id];
return param_info_->param_is_symbolic(id);
}

bool Context::param_has_value(int id) const {
return id >= 0 && id < (int)param_info_->is_parameter_symbolic_.size() &&
!param_info_->is_parameter_symbolic_[id];
return param_info_->param_has_value(id);
}

bool Context::param_is_expression(int id) const {
return id >= 0 && id < (int)param_info_->parameter_wires_.size() &&
!param_info_->parameter_wires_[id]->input_gates.empty();
return param_info_->param_is_expression(id);
}

CircuitWire *Context::get_param_wire(int id) const {
if (id >= 0 && id < (int)param_info_->parameter_wires_.size()) {
return param_info_->parameter_wires_[id].get();
} else {
return nullptr; // out of range
}
return param_info_->get_param_wire(id);
}

std::vector<ParamType>
Context::compute_parameters(const std::vector<ParamType> &input_parameters) {
auto result = input_parameters;
result.resize(param_info_->is_parameter_symbolic_.size());
for (auto &expr : param_info_->parameter_expressions_) {
std::vector<ParamType> params;
for (const auto &input_wire : expr->input_wires) {
params.push_back(result[input_wire->index]);
}
assert(expr->output_wires.size() == 1);
const auto &output_wire = expr->output_wires[0];
result[output_wire->index] = expr->gate->compute(params);
}
return result;
return param_info_->compute_parameters(input_parameters);
}

std::vector<int> Context::get_param_permutation(
@@ -353,21 +329,7 @@ void Context::generate_parameter_expressions(
}

std::vector<InputParamMaskType> Context::get_param_masks() const {
std::vector<InputParamMaskType> param_mask(
param_info_->is_parameter_symbolic_.size());
for (int i = 0; i < (int)param_mask.size(); i++) {
if (!param_is_expression(i)) {
param_mask[i] = ((InputParamMaskType)1) << i;
}
}
for (auto &expr : param_info_->parameter_expressions_) {
const auto &output_wire = expr->output_wires[0];
param_mask[output_wire->index] = 0;
for (const auto &input_wire : expr->input_wires) {
param_mask[output_wire->index] |= param_mask[input_wire->index];
}
}
return param_mask;
return param_info_->get_param_masks();
}

std::string Context::param_info_to_json() const {
10 changes: 0 additions & 10 deletions src/quartz/context/context.h
Original file line number Diff line number Diff line change
@@ -9,7 +9,6 @@
#include <algorithm>
#include <memory>
#include <random>
#include <set>
#include <unordered_map>
#include <vector>

@@ -163,15 +162,6 @@ class Context {
private:
bool insert_gate(GateType tp);

/**
* Generate random values for random testing for input symbolic parameters.
* The results are stored in |random_parameters_|.
* The size of |random_parameters_| should be equal to the number of input
* symbolic parameters.
* @param num_params The number of input symbolic parameters.
*/
void gen_random_parameters(int num_params);

size_t global_unique_id;
std::unordered_map<GateType, std::unique_ptr<Gate>> gates_;
std::unordered_map<
13 changes: 8 additions & 5 deletions src/quartz/context/param_info.h
Original file line number Diff line number Diff line change
@@ -3,14 +3,9 @@
#include "quartz/circuitseq/circuitgate.h"
#include "quartz/circuitseq/circuitwire.h"
#include "quartz/gate/gate_utils.h"
#include "quartz/math/vector.h"
#include "quartz/utils/utils.h"

#include <algorithm>
#include <memory>
#include <random>
#include <set>
#include <unordered_map>
#include <vector>

namespace quartz {
@@ -26,7 +21,15 @@ class ParamInfo {
*/
explicit ParamInfo(int num_input_symbolic_params);

/**
* Generate random values for random testing for input symbolic parameters.
* The results are stored in |random_parameters_|.
* The size of |random_parameters_| should be equal to the number of input
* symbolic parameters.
* @param num_params The number of input symbolic parameters.
*/
void gen_random_parameters(int num_params);

[[nodiscard]] std::vector<ParamType> get_all_generated_parameters() const;

/**