|
5 | 5 | #include <gtest/gtest.h> |
6 | 6 |
|
7 | 7 | #include "core/c3_miqp.h" |
| 8 | +#include "core/c3_qp.h" |
8 | 9 | #include "core/test/c3_cartpole_problem.hpp" |
9 | 10 |
|
10 | 11 | #include "drake/math/discrete_algebraic_riccati_equation.h" |
@@ -331,37 +332,38 @@ TEST_F(C3CartpoleTest, ZSolStaleTest) { |
331 | 332 | } |
332 | 333 | } |
333 | 334 |
|
334 | | -// Test the cartpole example |
335 | | -// This test will take some time to complete ~30s |
336 | | -TEST_F(C3CartpoleTest, End2EndCartpoleTest) { |
337 | | - /// initialize ADMM variables (delta, w) |
338 | | - std::vector<VectorXd> delta(N, VectorXd::Zero(n + m + k)); |
339 | | - std::vector<VectorXd> w(N, VectorXd::Zero(n + m + k)); |
| 335 | +template <typename T> |
| 336 | +class C3CartpoleTypedTest : public testing::Test, public C3CartpoleProblem { |
| 337 | + protected: |
| 338 | + C3CartpoleTypedTest() |
| 339 | + : C3CartpoleProblem(0.411, 0.978, 0.6, 0.4267, 0.35, -0.35, 100, 9.81) { |
| 340 | + pOpt = std::make_unique<T>(*pSystem, cost, xdesired, options); |
| 341 | + } |
| 342 | + std::unique_ptr<T> pOpt; |
| 343 | +}; |
340 | 344 |
|
341 | | - /// initialize ADMM reset variables (delta, w are reseted to these values) |
342 | | - std::vector<VectorXd> delta_reset(N, VectorXd::Zero(n + m + k)); |
343 | | - std::vector<VectorXd> w_reset(N, VectorXd::Zero(n + m + k)); |
| 345 | +using projection_types = ::testing::Types<C3QP, C3MIQP>; |
| 346 | +TYPED_TEST_SUITE(C3CartpoleTypedTest, projection_types); |
344 | 347 |
|
| 348 | +// Test the cartpole example |
| 349 | +// This test will take some time to complete ~30s |
| 350 | +TYPED_TEST(C3CartpoleTypedTest, End2EndCartpoleTest) { |
345 | 351 | int timesteps = 1000; // number of timesteps for the simulation |
346 | 352 |
|
347 | 353 | /// create state and input arrays |
348 | | - std::vector<VectorXd> x(timesteps, VectorXd::Zero(n)); |
349 | | - std::vector<VectorXd> input(timesteps, VectorXd::Zero(k)); |
| 354 | + std::vector<VectorXd> x(timesteps, VectorXd::Zero(this->n)); |
| 355 | + std::vector<VectorXd> input(timesteps, VectorXd::Zero(this->k)); |
350 | 356 |
|
351 | | - x[0] = x0; |
| 357 | + x[0] = this->x0; |
352 | 358 |
|
353 | 359 | int close_to_zero_counter = 0; |
354 | 360 | for (int i = 0; i < timesteps - 1; i++) { |
355 | | - /// reset delta and w (default option) |
356 | | - delta = delta_reset; |
357 | | - w = w_reset; |
358 | | - |
359 | 361 | /// calculate the input given x[i] |
360 | | - pOpt->Solve(x[i]); |
361 | | - input[i] = pOpt->GetInputSolution()[0]; |
| 362 | + this->pOpt->Solve(x[i]); |
| 363 | + input[i] = this->pOpt->GetInputSolution()[0]; |
362 | 364 |
|
363 | 365 | /// simulate the LCS |
364 | | - x[i + 1] = pSystem->Simulate(x[i], input[i]); |
| 366 | + x[i + 1] = this->pSystem->Simulate(x[i], input[i]); |
365 | 367 | if (x[i + 1].isZero(0.1)) { |
366 | 368 | close_to_zero_counter++; |
367 | 369 | if (close_to_zero_counter == 30) break; |
|
0 commit comments