diff --git a/src/python/verifier/gates.py b/src/python/verifier/gates.py index fc066e44..61bc67a5 100644 --- a/src/python/verifier/gates.py +++ b/src/python/verifier/gates.py @@ -19,6 +19,53 @@ def neg(a): return cos_a, -sin_a +def mult(x, y): + # In this special case, both the lhs and the rhs are numbers. + if isinstance(x, (int, float)) and isinstance(y, (int, float)): + return x * y + + # To apply trigonometric angle formulas, one side must be a number. + # Without loss of generality, the left-hand side will be a number. + assert isinstance(x, (int, float)) or isinstance(y, (int, float)) + if isinstance(y, (int, float)): + x, y = y, x + + # This block ensures that the lhs is not only a number, but also an integer. + # This is because angle-reducing formula only exist for integer multipliers. + # Of course, other formulas exist, such as the half-angle formula. + # However, this formula is not determined (for arbitrary a) by (cos_a, sin_a) alone. + if isinstance(x, float): + assert x.is_integer() + x = int(x) + + # Moves negative signs from the left-hand side to the right-hand side. + if x < 0: + x = -x + y = neg(y) + + # Base Cases. + if x == 0: + return 1, 0 + elif x == 1: + return y + # Triple-angle formula. + elif x % 3 == 0: + cos_y, sin_y = mult(x // 3, y) + cos_z = 4 * cos_y * cos_y * cos_y - 3 * cos_y + sin_z = 3 * sin_y - 4 * sin_y * sin_y * sin_y + return cos_z, sin_z + # Double-angle formula. + elif x % 2 == 0: + cos_y, sin_y = mult(x // 2, y) + cos_z = cos_y * cos_y - sin_y * sin_y + sin_z = 2 * cos_y * sin_y + return cos_z, sin_z + # Otherwise, use the sum formula to decrease x by 1. + else: + z = mult(x - 1, y) + return add(y, z) + + # quantum gates diff --git a/src/quartz/gate/all_gates.h b/src/quartz/gate/all_gates.h index 55f1d905..a1cb0a0d 100644 --- a/src/quartz/gate/all_gates.h +++ b/src/quartz/gate/all_gates.h @@ -13,6 +13,7 @@ #include "h.h" #include "input_param.h" #include "input_qubit.h" +#include "mult.h" #include "neg.h" #include "p.h" #include "pdg.h" diff --git a/src/quartz/gate/gates.inc.h b/src/quartz/gate/gates.inc.h index bb5d7337..2e4d70ad 100644 --- a/src/quartz/gate/gates.inc.h +++ b/src/quartz/gate/gates.inc.h @@ -10,6 +10,7 @@ PER_GATE(rz, RZGate) PER_GATE(cx, CXGate) PER_GATE(ccx, CCXGate) PER_GATE(add, AddGate) +PER_GATE(mult, MultGate) PER_GATE(neg, NegGate) PER_GATE(z, ZGate) PER_GATE(s, SGate) diff --git a/src/quartz/gate/mult.h b/src/quartz/gate/mult.h new file mode 100644 index 00000000..22241200 --- /dev/null +++ b/src/quartz/gate/mult.h @@ -0,0 +1,18 @@ +#pragma once + +#include "gate.h" + +#include + +namespace quartz { +class MultGate : public Gate { + public: + MultGate() : Gate(GateType::mult, 0 /*num_qubits*/, 2 /*num_parameters*/) {} + ParamType compute(const std::vector &input_params) override { + assert(input_params.size() == 2); + return input_params[0] * input_params[1]; + } + bool is_commutative() const override { return true; } +}; + +} // namespace quartz diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt index 99eaf718..b5eb73cc 100644 --- a/src/test/CMakeLists.txt +++ b/src/test/CMakeLists.txt @@ -31,6 +31,7 @@ file(GLOB_RECURSE TEST_FROM_AND_TO_QASM "test_from_and_to_qasm.cpp") file(GLOB_RECURSE TEST_OPTIMIZE "test_optimize.cpp") file(GLOB_RECURSE TEST_CREATE_GRAPHXFER_FROM_QASM "test_create_graphXfer_from_qasm.cpp") file(GLOB_RECURSE TEST_PARTITION "test_partition.cpp") +file(GLOB_RECURSE TEST_MULT "test_mult.cpp") if(USE_ARBLIB) file(GLOB_RECURSE TEST_ARB "test_arb.cpp") endif() @@ -66,6 +67,7 @@ add_executable(test_from_and_to_qasm ${TEST_FROM_AND_TO_QASM} ) add_executable(test_optimize ${TEST_OPTIMIZE} ) add_executable(test_create_graphXfer_from_qasm ${TEST_CREATE_GRAPHXFER_FROM_QASM} ) add_executable(test_partition ${TEST_PARTITION} ) +add_executable(test_mult ${TEST_MULT} ) if(USE_ARBLIB) add_executable(test_arb ${TEST_ARB} ) endif() diff --git a/src/test/test_mult.cpp b/src/test/test_mult.cpp new file mode 100644 index 00000000..40ce0b42 --- /dev/null +++ b/src/test/test_mult.cpp @@ -0,0 +1,31 @@ +#include "quartz/circuitseq/circuitseq.h" +#include "quartz/context/context.h" +#include "quartz/gate/gate.h" + +#include + +using namespace quartz; + +int main() { + ParamInfo param_info(0); + Context ctx({GateType::rx, GateType::mult}, 1, ¶m_info); + + auto p0 = ctx.get_new_param_id(2.0); + auto p1 = ctx.get_new_param_id(3.0); + auto p2 = ctx.get_new_param_id(6.0); + auto p3 = + ctx.get_new_param_expression_id({p0, p1}, ctx.get_gate(GateType::mult)); + + CircuitSeq dag1(1); + dag1.add_gate({0}, {p2}, ctx.get_gate(GateType::rx), &ctx); + + CircuitSeq dag2(1); + dag2.add_gate({0}, {p3}, ctx.get_gate(GateType::rx), &ctx); + + auto c1 = dag1.to_qasm_style_string(&ctx); + auto c2 = dag2.to_qasm_style_string(&ctx); + assert(c1 == c2); + + // Working directory is cmake-build-debug/ here. + system("python ../src/test/test_mult.py"); +} diff --git a/src/test/test_mult.py b/src/test/test_mult.py new file mode 100644 index 00000000..770979f9 --- /dev/null +++ b/src/test/test_mult.py @@ -0,0 +1,62 @@ +import sys + +sys.path.append("..") + +from src.python.verifier.gates import * + + +def approx_eq(a, b): + assert len(a) == 2 + assert len(b) == 2 + cos_a, sin_a = a + cos_b, sin_b = b + + err = max(abs(cos_a - cos_b), abs(sin_a - sin_b)) + return err < 0.0000000000001 + + +def mult_test(expected, n, a): + actual = mult(n, a) + swapped = mult(a, n) + + assert actual == swapped + assert approx_eq(actual, expected) + + +def test_positive(a): + expected = 1, 0 + for n in range(0, 41): + mult_test(expected, n, a) + expected = add(a, expected) + + +def test_negative(a): + expected = neg(a) + for n in range(-1, -41, -1): + mult_test(expected, n, a) + expected = add(neg(a), expected) + + +def test_floats(a): + n = 5 + expected = mult(n, a) + + actual = mult(float(n), a) + swapped = mult(a, float(n)) + + assert actual == swapped + assert actual == expected + + +def test_numbers(): + assert mult(2, 3.0) == 6.0 + assert mult(3.0, 2) == 6.0 + assert mult(3, 4) == 12 + assert mult(3.0, 0.1) == 0.3 + + +if __name__ == '__main__': + v = 1 / math.sqrt(2) + test_positive((v, v)) + test_negative((v, v)) + test_floats((v, v))