Skip to content

Commit 2bd1349

Browse files
committed
feat: reorganize
1 parent 204b0af commit 2bd1349

File tree

5 files changed

+27
-34
lines changed

5 files changed

+27
-34
lines changed

core/c3.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -388,9 +388,10 @@ void C3::ADMMStep(const VectorXd& x0, vector<VectorXd>* delta,
388388
}
389389
}
390390

391-
void C3::WarmStartQP(const Eigen::VectorXd& x0, int admm_iteration) {
392-
393-
if (admm_iteration == 0) return; // No warm start for the first iteration
391+
void C3::SetInitialGuessQP(const Eigen::VectorXd& x0, int admm_iteration) {
392+
prog_.SetInitialGuess(x_[0], x0);
393+
if (!warm_start_ || admm_iteration == 0)
394+
return; // No warm start for the first iteration
394395
int index = solve_time_ / lcs_.dt();
395396
double weight = (solve_time_ - index * lcs_.dt()) / lcs_.dt();
396397
for (int i = 0; i < N_ - 1; ++i) {
@@ -407,8 +408,8 @@ void C3::WarmStartQP(const Eigen::VectorXd& x0, int admm_iteration) {
407408
prog_.SetInitialGuess(x_[N_], warm_start_x_[admm_iteration - 1][N_]);
408409
}
409410

410-
void C3::ProcessQPResults(const MathematicalProgramResult& result,
411-
int admm_iteration, bool is_final_solve) {
411+
void C3::StoreQPResults(const MathematicalProgramResult& result,
412+
int admm_iteration, bool is_final_solve) {
412413
for (int i = 0; i < N_; ++i) {
413414
if (is_final_solve) {
414415
x_sol_->at(i) = result.GetSolution(x_[i]);
@@ -449,8 +450,7 @@ vector<VectorXd> C3::SolveQP(const VectorXd& x0, const vector<MatrixXd>& G,
449450
-2 * G.at(i) * WD.at(i));
450451
}
451452

452-
prog_.SetInitialGuess(x_[0], x0);
453-
if (warm_start_) WarmStartQP(x0, admm_iteration);
453+
SetInitialGuessQP(x0, admm_iteration);
454454

455455
MathematicalProgramResult result = osqp_.Solve(prog_);
456456

@@ -459,7 +459,7 @@ vector<VectorXd> C3::SolveQP(const VectorXd& x0, const vector<MatrixXd>& G,
459459
result.get_solution_result());
460460
}
461461

462-
ProcessQPResults(result, admm_iteration, is_final_solve);
462+
StoreQPResults(result, admm_iteration, is_final_solve);
463463

464464
return *z_sol_;
465465
}
@@ -476,7 +476,7 @@ vector<VectorXd> C3::SolveProjection(const vector<MatrixXd>& U,
476476
}
477477

478478
#pragma omp parallel for num_threads( \
479-
options_.num_threads) if (use_parallelization_in_projection_)
479+
options_.num_threads) if (use_parallelization_in_projection_)
480480
for (int i = 0; i < N_; ++i) {
481481
if (warm_start_) {
482482
if (i == N_ - 1) {

core/c3.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,8 @@ class C3 {
247247
const Eigen::MatrixXd& H, const Eigen::VectorXd& c,
248248
const int admm_iteration, const int& warm_start_index) = 0;
249249

250-
virtual void WarmStartQP(const Eigen::VectorXd& x0, int admm_iteration);
251-
virtual void ProcessQPResults(
250+
virtual void SetInitialGuessQP(const Eigen::VectorXd& x0, int admm_iteration);
251+
virtual void StoreQPResults(
252252
const drake::solvers::MathematicalProgramResult& result,
253253
int admm_iteration, bool is_final_solve);
254254
/*!

core/c3_plus.cc

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -65,30 +65,22 @@ void C3Plus::UpdateLCS(const LCS& lcs) {
6565
}
6666
}
6767

68-
void C3Plus::WarmStartQP(const Eigen::VectorXd& x0, int admm_iteration) {
69-
// if (admm_iteration == 0) return; // No warm start for the first iteration
68+
void C3Plus::SetInitialGuessQP(const Eigen::VectorXd& x0, int admm_iteration) {
69+
C3::SetInitialGuessQP(x0, admm_iteration);
70+
if (!warm_start_ || admm_iteration == 0)
71+
return; // No warm start for the first iteration
7072
int index = solve_time_ / lcs_.dt();
7173
double weight = (solve_time_ - index * lcs_.dt()) / lcs_.dt();
7274
for (int i = 0; i < N_ - 1; ++i) {
73-
prog_.SetInitialGuess(x_[i],
74-
(1 - weight) * warm_start_x_[admm_iteration][i] +
75-
weight * warm_start_x_[admm_iteration][i + 1]);
7675
prog_.SetInitialGuess(
77-
lambda_[i], (1 - weight) * warm_start_lambda_[admm_iteration][i] +
78-
weight * warm_start_lambda_[admm_iteration][i + 1]);
79-
prog_.SetInitialGuess(u_[i],
80-
(1 - weight) * warm_start_u_[admm_iteration][i] +
81-
weight * warm_start_u_[admm_iteration][i + 1]);
82-
prog_.SetInitialGuess(eta_[i],
83-
(1 - weight) * warm_start_eta_[admm_iteration][i] +
84-
weight * warm_start_eta_[admm_iteration][i + 1]);
76+
eta_[i], (1 - weight) * warm_start_eta_[admm_iteration - 1][i] +
77+
weight * warm_start_eta_[admm_iteration - 1][i + 1]);
8578
}
86-
prog_.SetInitialGuess(x_[N_], warm_start_x_[admm_iteration][N_]);
8779
}
8880

89-
void C3Plus::ProcessQPResults(const MathematicalProgramResult& result,
90-
int admm_iteration, bool is_final_solve) {
91-
C3::ProcessQPResults(result, admm_iteration, is_final_solve);
81+
void C3Plus::StoreQPResults(const MathematicalProgramResult& result,
82+
int admm_iteration, bool is_final_solve) {
83+
C3::StoreQPResults(result, admm_iteration, is_final_solve);
9284
for (int i = 0; i < N_; i++) {
9385
if (is_final_solve) {
9486
eta_sol_->at(i) = result.GetSolution(eta_[i]);

core/c3_plus.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,11 @@ class C3Plus final : public C3 {
5555
std::vector<std::vector<Eigen::VectorXd>> warm_start_eta_;
5656

5757
private:
58-
void ProcessQPResults(const drake::solvers::MathematicalProgramResult& result,
59-
int admm_iteration, bool is_final_solve) override;
58+
void StoreQPResults(const drake::solvers::MathematicalProgramResult& result,
59+
int admm_iteration, bool is_final_solve) override;
6060
void UpdateLCS(const LCS& lcs) override;
61-
void WarmStartQP(const Eigen::VectorXd& x0, int admm_iteration) override;
61+
void SetInitialGuessQP(const Eigen::VectorXd& x0,
62+
int admm_iteration) override;
6263
std::vector<drake::solvers::VectorXDecisionVariable> eta_;
6364
std::unique_ptr<std::vector<Eigen::VectorXd>> eta_sol_;
6465

multibody/geom_geom_collider.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ void GeomGeomCollider<T>::ComputeSphereMeshDistance(const Context<T>& context,
7878
mesh_set.Add(mesh_id);
7979
const auto sd_set = query_object.ComputeSignedDistanceGeometryToPoint(
8080
sphere_center, mesh_set);
81-
DRAKE_ASSERT(sd_set.size() > 0);
81+
DRAKE_DEMAND(sd_set.size() == 1);
8282
SignedDistanceToPoint<T> sd_to_point = sd_set[0];
8383

8484
// Compute contact distance and normal.
@@ -90,11 +90,11 @@ void GeomGeomCollider<T>::ComputeSphereMeshDistance(const Context<T>& context,
9090
nhat_BA_W = -nhat_BA_W;
9191
p_ACa = inspector.GetPoseInFrame(geometry_id_A_).template cast<T>() *
9292
sd_to_point.p_GN;
93-
p_BCb = X_FS.template cast<T>() * (-sphere_radius * nhat_BA_W);
93+
p_BCb = X_FS.template cast<T>() * (-1 * sphere_radius * nhat_BA_W);
9494
} else {
9595
p_BCb = inspector.GetPoseInFrame(geometry_id_B_).template cast<T>() *
9696
sd_to_point.p_GN;
97-
p_ACa = X_FS.template cast<T>() * (-sphere_radius * nhat_BA_W);
97+
p_ACa = X_FS.template cast<T>() * (-1 * sphere_radius * nhat_BA_W);
9898
}
9999
}
100100

0 commit comments

Comments
 (0)