Skip to content

Commit 67c7f9d

Browse files
authored
[Verifier] Distinguish parameters only for arithmetic expressions from concrete parameters which uses trigonometric functions (#204)
* [Verifier] Distinguish parameters only for arithmetic expressions from concrete parameters which uses trigonometric functions * code format * Parse int/float when reading param info
1 parent 90940fd commit 67c7f9d

File tree

10 files changed

+210
-79
lines changed

10 files changed

+210
-79
lines changed

src/python/verifier/verifier.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,9 @@ def compute_params(param_info):
601601
for i in range(len(param_info)):
602602
if param_info[i] == "": # symbolic
603603
params.append(symbolic_params.pop(0))
604-
elif isinstance(param_info[i], (int, float)): # concrete
604+
elif isinstance(param_info[i], int): # concrete parameter for calculation
605+
params.append(param_info[i])
606+
elif isinstance(param_info[i], float): # concrete parameter to be directly used
605607
params.append((math.cos(param_info[i]), math.sin(param_info[i])))
606608
else: # expression
607609
op = param_info[i][0]

src/quartz/circuitseq/circuitgate.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ std::string CircuitGate::to_qasm_style_string(Context *ctx,
265265
for (auto input_wire : input_wires) {
266266
if (input_wire->is_parameter()) {
267267
// Ensures the wire is valid.
268-
assert(ctx->param_has_value(input_wire->index));
268+
assert(ctx->param_is_const(input_wire->index));
269269

270270
// Determines the parameter value with respect to reparameterization.
271271
std::ostringstream out;

src/quartz/circuitseq/circuitseq.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -996,7 +996,7 @@ std::string CircuitSeq::to_qasm_style_string(Context *ctx,
996996
int param_precision) const {
997997
// Checks if parameters are in use.
998998
for (auto param : get_input_param_indices(ctx)) {
999-
if (ctx->param_is_symbolic(param)) {
999+
if (!ctx->param_is_const(param)) {
10001000
std::cerr << "to_qasm_style_string only supports consts." << std::endl;
10011001
break;
10021002
}

src/quartz/context/context.cpp

+28-29
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,10 @@ int Context::get_new_param_id(const ParamType &param) {
217217
return param_info_->get_new_param_id(param);
218218
}
219219

220+
int Context::get_new_arithmetic_param_id(const ParamType &param) {
221+
return param_info_->get_new_arithmetic_param_id(param);
222+
}
223+
220224
int Context::get_new_param_id() {
221225
return param_info_->get_new_param_id(may_use_halved_params_);
222226
}
@@ -238,8 +242,8 @@ bool Context::param_is_symbolic(int id) const {
238242
return param_info_->param_is_symbolic(id);
239243
}
240244

241-
bool Context::param_has_value(int id) const {
242-
return param_info_->param_has_value(id);
245+
bool Context::param_is_const(int id) const {
246+
return param_info_->param_is_const(id);
243247
}
244248

245249
bool Context::param_is_expression(int id) const {
@@ -261,7 +265,7 @@ Context::compute_parameters(const std::vector<ParamType> &input_parameters) {
261265

262266
std::vector<int> Context::get_param_permutation(
263267
const std::vector<int> &input_param_permutation) {
264-
int num_parameters = (int)param_info_->is_parameter_symbolic_.size();
268+
int num_parameters = (int)param_info_->parameter_class_.size();
265269
std::vector<int> result = input_param_permutation;
266270
result.resize(num_parameters, -1); // fill with -1
267271
for (int i = (int)input_param_permutation.size(); i < num_parameters; i++) {
@@ -307,7 +311,7 @@ std::vector<int> Context::get_param_permutation(
307311
void Context::generate_parameter_expressions(
308312
int max_num_operators_per_expression) {
309313
assert(max_num_operators_per_expression == 1);
310-
int num_input_parameters = (int)param_info_->is_parameter_symbolic_.size();
314+
int num_input_parameters = (int)param_info_->parameter_class_.size();
311315
assert(num_input_parameters > 0);
312316
if (!param_info_->parameter_expressions_.empty()) {
313317
std::cerr << "Context::generate_parameter_expressions() called twice for a "
@@ -350,25 +354,7 @@ std::vector<InputParamMaskType> Context::get_param_masks() const {
350354
}
351355

352356
std::string Context::param_info_to_json() const {
353-
std::string result = "[";
354-
result += "[";
355-
result += std::to_string(param_info_->is_parameter_symbolic_.size());
356-
for (int i = 0; i < (int)param_info_->is_parameter_symbolic_.size(); i++) {
357-
result += ", ";
358-
if (param_is_expression(i)) {
359-
result += param_info_->parameter_wires_[i]->input_gates[0]->to_json();
360-
} else if (param_info_->is_parameter_symbolic_[i]) {
361-
result += "\"\"";
362-
} else {
363-
result += to_string_with_precision(param_info_->parameter_values_[i],
364-
/*precision=*/17);
365-
}
366-
}
367-
result += "], ";
368-
result += to_json_style_string_with_precision(param_info_->random_parameters_,
369-
/*precision=*/17);
370-
result += "]";
371-
return result;
357+
return param_info_->to_json();
372358
}
373359

374360
bool Context::load_param_info_from_json(std::istream &fin) {
@@ -386,8 +372,8 @@ bool Context::load_param_info_from_json(std::istream &fin) {
386372
}
387373
int num_params;
388374
fin >> num_params;
389-
param_info_->is_parameter_symbolic_.clear();
390-
param_info_->is_parameter_symbolic_.reserve(num_params);
375+
param_info_->parameter_class_.clear();
376+
param_info_->parameter_class_.reserve(num_params);
391377
param_info_->parameter_wires_.clear();
392378
param_info_->parameter_wires_.reserve(num_params);
393379
param_info_->parameter_values_.clear();
@@ -416,10 +402,23 @@ bool Context::load_param_info_from_json(std::istream &fin) {
416402
assert(id == i);
417403
} else {
418404
// concrete parameter
419-
fin.unget();
420-
ParamType val;
421-
fin >> val;
422-
int id = get_new_param_id(val);
405+
bool is_float = false;
406+
std::string s; // record the number string
407+
while (ch != ',' && ch != ']') {
408+
s += ch;
409+
if (ch == '.' || ch == 'e' || ch == 'E') {
410+
is_float = true;
411+
}
412+
fin >> ch;
413+
}
414+
fin.unget(); // put the ',' or ']' back
415+
ParamType val = std::stod(s);
416+
int id;
417+
if (is_float) {
418+
id = get_new_param_id(val);
419+
} else {
420+
id = get_new_arithmetic_param_id(val);
421+
}
423422
assert(id == i);
424423
}
425424
}

src/quartz/context/context.h

+7-1
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ class Context {
8181
* @return The index of the new concrete parameter.
8282
*/
8383
int get_new_param_id(const ParamType &param);
84+
/**
85+
* Create a new concrete (integer) parameter that can only be used in
86+
* arithmetic expressions.
87+
* @return The index of the new concrete parameter.
88+
*/
89+
int get_new_arithmetic_param_id(const ParamType &param);
8490
/**
8591
* Create a new symbolic parameter.
8692
* @return The index of the new symbolic parameter.
@@ -98,7 +104,7 @@ class Context {
98104
[[nodiscard]] int get_num_parameters() const;
99105
[[nodiscard]] int get_num_input_symbolic_parameters() const;
100106
[[nodiscard]] bool param_is_symbolic(int id) const;
101-
[[nodiscard]] bool param_has_value(int id) const;
107+
[[nodiscard]] bool param_is_const(int id) const;
102108
[[nodiscard]] bool param_is_expression(int id) const;
103109
[[nodiscard]] bool param_is_halved(int id) const;
104110

src/quartz/context/param_info.cpp

+78-27
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include "param_info.h"
22

3+
#include "quartz/utils/string_utils.h"
4+
35
#include <cassert>
46

57
namespace quartz {
@@ -36,14 +38,15 @@ std::vector<ParamType> ParamInfo::get_all_generated_parameters() const {
3638
}
3739

3840
ParamType ParamInfo::get_param_value(int id) const {
39-
assert(id >= 0 && id < (int)parameter_values_.size());
40-
assert(!is_parameter_symbolic_[id]);
41+
assert(id >= 0 && id < (int)parameter_class_.size());
42+
assert(parameter_class_[id].is_const());
4143
return parameter_values_[id];
4244
}
4345

4446
void ParamInfo::set_param_value(int id, const ParamType &param) {
45-
assert(id >= 0 && id < (int)is_parameter_symbolic_.size());
46-
assert(!is_parameter_symbolic_[id]);
47+
assert(id >= 0 && id < (int)parameter_class_.size());
48+
assert(parameter_class_[id] == ParamClass::concrete_const ||
49+
parameter_class_[id] == ParamClass::arithmetic_int);
4750
while (id >= (int)parameter_values_.size()) {
4851
parameter_values_.emplace_back();
4952
}
@@ -55,10 +58,22 @@ std::vector<ParamType> ParamInfo::get_all_input_param_values() const {
5558
}
5659

5760
int ParamInfo::get_new_param_id(const ParamType &param) {
58-
int id = (int)is_parameter_symbolic_.size();
59-
assert(id == (int)is_parameter_halved_.size());
60-
is_parameter_symbolic_.push_back(false);
61-
is_parameter_halved_.push_back(false);
61+
int id = (int)parameter_class_.size();
62+
assert(id == (int)parameter_class_.size());
63+
parameter_class_.emplace_back(ParamClass::concrete_const);
64+
auto wire = std::make_unique<CircuitWire>();
65+
wire->type = CircuitWire::input_param;
66+
wire->index = id;
67+
parameter_wires_.push_back(std::move(wire));
68+
set_param_value(id, param);
69+
return id;
70+
}
71+
72+
int ParamInfo::get_new_arithmetic_param_id(const ParamType &param) {
73+
int id = (int)parameter_class_.size();
74+
assert(id == (int)parameter_class_.size());
75+
assert((int)param == param);
76+
parameter_class_.emplace_back(ParamClass::arithmetic_int);
6277
auto wire = std::make_unique<CircuitWire>();
6378
wire->type = CircuitWire::input_param;
6479
wire->index = id;
@@ -68,9 +83,9 @@ int ParamInfo::get_new_param_id(const ParamType &param) {
6883
}
6984

7085
int ParamInfo::get_new_param_id(bool is_halved) {
71-
int id = (int)is_parameter_symbolic_.size();
72-
is_parameter_symbolic_.push_back(true);
73-
is_parameter_halved_.push_back(is_halved);
86+
int id = (int)parameter_class_.size();
87+
parameter_class_.emplace_back(is_halved ? ParamClass::symbolic_halved
88+
: ParamClass::symbolic);
7489
// Make sure to generate a random parameter for each symbolic parameter.
7590
gen_random_parameters(id + 1);
7691
auto wire = std::make_unique<CircuitWire>();
@@ -83,11 +98,15 @@ int ParamInfo::get_new_param_id(bool is_halved) {
8398
int ParamInfo::get_new_param_expression_id(
8499
const std::vector<int> &parameter_indices, Gate *op) {
85100
bool is_symbolic = is_symbolic_constant(op);
101+
bool is_const = true;
86102
for (auto &input_id : parameter_indices) {
87-
assert(input_id >= 0 && input_id < (int)is_parameter_symbolic_.size());
88-
if (param_is_symbolic(input_id)) {
103+
assert(input_id >= 0 && input_id < (int)parameter_class_.size());
104+
if (parameter_class_[input_id].is_symbolic()) {
89105
is_symbolic = true;
90106
}
107+
if (!parameter_class_[input_id].is_const()) {
108+
is_const = false;
109+
}
91110
}
92111
if (!is_symbolic) {
93112
// A concrete parameter, no need to create an expression.
@@ -99,9 +118,9 @@ int ParamInfo::get_new_param_expression_id(
99118
}
100119
return get_new_param_id(op->compute(input_params));
101120
}
102-
int id = (int)is_parameter_symbolic_.size();
103-
is_parameter_symbolic_.push_back(true);
104-
is_parameter_halved_.push_back(false);
121+
int id = (int)parameter_class_.size();
122+
parameter_class_.emplace_back(is_const ? ParamClass::symbolic_constexpr
123+
: ParamClass::expression);
105124
auto circuit_gate = std::make_unique<CircuitGate>();
106125
circuit_gate->gate = op;
107126
for (auto &input_id : parameter_indices) {
@@ -119,21 +138,21 @@ int ParamInfo::get_new_param_expression_id(
119138
}
120139

121140
int ParamInfo::get_num_parameters() const {
122-
return (int)is_parameter_symbolic_.size();
141+
return (int)parameter_class_.size();
123142
}
124143

125144
int ParamInfo::get_num_input_symbolic_parameters() const {
126145
return (int)random_parameters_.size();
127146
}
128147

129148
bool ParamInfo::param_is_symbolic(int id) const {
130-
return id >= 0 && id < (int)is_parameter_symbolic_.size() &&
131-
is_parameter_symbolic_[id];
149+
return id >= 0 && id < (int)parameter_class_.size() &&
150+
parameter_class_[id].is_symbolic();
132151
}
133152

134-
bool ParamInfo::param_has_value(int id) const {
135-
return id >= 0 && id < (int)is_parameter_symbolic_.size() &&
136-
!is_parameter_symbolic_[id];
153+
bool ParamInfo::param_is_const(int id) const {
154+
return id >= 0 && id < (int)parameter_class_.size() &&
155+
parameter_class_[id].is_const();
137156
}
138157

139158
bool ParamInfo::param_is_expression(int id) const {
@@ -142,8 +161,9 @@ bool ParamInfo::param_is_expression(int id) const {
142161
}
143162

144163
bool ParamInfo::param_is_halved(int id) const {
145-
return id >= 0 && id < (int)is_parameter_halved_.size() &&
146-
is_parameter_halved_[id];
164+
return id >= 0 && id < (int)parameter_class_.size() &&
165+
parameter_class_[id] == ParamClass::symbolic_halved;
166+
// TODO: halved parameter expressions
147167
}
148168

149169
CircuitWire *ParamInfo::get_param_wire(int id) const {
@@ -158,10 +178,10 @@ std::vector<ParamType>
158178
ParamInfo::compute_parameters(const std::vector<ParamType> &input_parameters) {
159179
// Creates a param list, assuming that all symbolic params are defined first.
160180
auto result = input_parameters;
161-
result.resize(is_parameter_symbolic_.size());
181+
result.resize(parameter_class_.size());
162182
// Populates constant parameters.
163183
for (int i = 0; i < result.size(); ++i) {
164-
if (!is_parameter_symbolic_[i]) {
184+
if (parameter_class_[i].is_const()) {
165185
result[i] = parameter_values_[i];
166186
}
167187
}
@@ -180,7 +200,7 @@ ParamInfo::compute_parameters(const std::vector<ParamType> &input_parameters) {
180200
}
181201

182202
std::vector<InputParamMaskType> ParamInfo::get_param_masks() const {
183-
std::vector<InputParamMaskType> param_mask(is_parameter_symbolic_.size());
203+
std::vector<InputParamMaskType> param_mask(parameter_class_.size());
184204
for (int i = 0; i < (int)param_mask.size(); i++) {
185205
if (!param_is_expression(i)) {
186206
param_mask[i] = ((InputParamMaskType)1) << i;
@@ -195,4 +215,35 @@ std::vector<InputParamMaskType> ParamInfo::get_param_masks() const {
195215
}
196216
return param_mask;
197217
}
218+
219+
std::string ParamInfo::to_json() const {
220+
std::string result = "[";
221+
result += "[";
222+
result += std::to_string(parameter_class_.size());
223+
for (int i = 0; i < (int)parameter_class_.size(); i++) {
224+
result += ", ";
225+
if (parameter_class_[i] == ParamClass::arithmetic_int) {
226+
// arithmetic int
227+
result += std::to_string((int)parameter_values_[i]);
228+
} else if (parameter_class_[i].is_input()) {
229+
if (parameter_class_[i].is_symbolic()) {
230+
// input symbolic
231+
result += "\"\"";
232+
// TODO: halved parameter
233+
} else {
234+
// input concrete
235+
result += to_string_with_precision(parameter_values_[i],
236+
/*precision=*/17);
237+
}
238+
} else {
239+
// expression
240+
result += parameter_wires_[i]->input_gates[0]->to_json();
241+
}
242+
}
243+
result += "], ";
244+
result += to_json_style_string_with_precision(random_parameters_,
245+
/*precision=*/17);
246+
result += "]";
247+
return result;
248+
}
198249
} // namespace quartz

0 commit comments

Comments
 (0)