Fixup the ConjugateGradient constructor and adapt blog code (#46)
Also prettify example output
daniel-vainsencher authored May 25, 2021
1 parent 169b9ac commit 799dcb4
Showing 5 changed files with 48 additions and 41 deletions.
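The migration this commit performs at every call site is mechanical: the free function conjugate_gradient(&p) is replaced by the associated constructor ConjugateGradient::for_problem(&p). A minimal sketch of the change, using only the crate paths that appear in the diff below:

use iterative_methods::conjugate_gradient::ConjugateGradient;
use iterative_methods::utils::make_3x3_pd_system_2;

fn main() {
    // A small positive-definite test system (helper used throughout this diff).
    let p = make_3x3_pd_system_2();

    // Old call site, removed in this commit:
    // let cg_iter = conjugate_gradient(&p);

    // New call site: construction is now an associated function on the struct.
    let _cg_iter = ConjugateGradient::for_problem(&p);
}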
4 changes: 2 additions & 2 deletions examples/conjugate_gradient_method.rs
@@ -2,7 +2,7 @@ extern crate eigenvalues;
extern crate nalgebra as na;
use streaming_iterator::*;

use iterative_methods::conjugate_gradient::{conjugate_gradient, ConjugateGradient};
use iterative_methods::conjugate_gradient::ConjugateGradient;
use iterative_methods::utils::make_3x3_pd_system_2;
use iterative_methods::*;

@@ -35,7 +35,7 @@ fn cg_demo() {
let optimum = rcarr1(&[-4.0, 6., -4.]);

// Initialize the conjugate gradient solver on this problem
let cg_iter = conjugate_gradient(&p);
let cg_iter = ConjugateGradient::for_problem(&p);

// Cap the number of iterations.
let cg_iter = cg_iter.take(80);
15 changes: 10 additions & 5 deletions examples/conjugate_gradient_method_for_blog.rs
@@ -6,7 +6,7 @@ extern crate eigenvalues;
extern crate nalgebra as na;
use streaming_iterator::*;

use iterative_methods::conjugate_gradient::{conjugate_gradient, ConjugateGradient};
use iterative_methods::conjugate_gradient::ConjugateGradient;
use iterative_methods::utils::make_3x3_pd_system_2;
use iterative_methods::*;

@@ -21,7 +21,7 @@ fn cg_demo_pt1() {
let p = make_3x3_pd_system_2();

// Next convert it into an iterator
let mut cg_iter = conjugate_gradient(&p);
let mut cg_iter = ConjugateGradient::for_problem(&p);

// and loop over intermediate solutions.
// Note `next` is provided by the StreamingIterator trait using
@@ -40,11 +40,12 @@ fn cg_demo_pt1() {
// || ... ||_2 is notation for euclidean length of what
// lies between the vertical lines.
println!(
"||Ax - b||_2 = {:.5}, for x = {:.4}, residual = {:.7}",
"||Ax - b||_2 = {:.5}, for x = {:+.3}, residual = {:+.3}",
res_squared_length.sqrt(),
result.solution,
res
);
// Stop if residual is small enough
if res_squared_length < 1e-3 {
break;
}
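The format-string change above is the "prettify example output" part of the commit: the solution and residual switch from the old {:.4} and {:.7} specifiers to the shorter, explicitly signed {:+.3}. A tiny standalone illustration of the difference on plain f64 values (the demo applies the same specifiers to ndarray values):

fn main() {
    let value = -4.00000123_f64;
    // Old precision, as in the removed format string:
    println!("{:.7}", value); // prints -4.0000012
    // New format: three decimals, with an explicit sign for positive values too:
    println!("{:+.3}", value); // prints -4.000
    println!("{:+.3}", 6.0_f64); // prints +6.000
}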
@@ -61,7 +62,7 @@ fn residual_l2(result: &ConjugateGradient) -> f64 {
/// https://daniel-vainsencher.github.io/book/iterative_methods_part_2.html
fn cg_demo_pt2_1() {
let p = make_3x3_pd_system_2();
let cg_iter = conjugate_gradient(&p);
let cg_iter = ConjugateGradient::for_problem(&p);

// Annotate each approximate solution with its cost
let cg_iter = assess(cg_iter, residual_l2);
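For context, a hedged sketch of how the annotated stream set up here is typically consumed. It assumes only names that appear elsewhere in this diff (assess, AnnotatedResult, the StreamingIterator trait, and the ConjugateGradient fields used in src/algorithms.rs), and it supplies a plausible body for residual_l2, whose definition is elided in this excerpt:

use streaming_iterator::StreamingIterator;
use iterative_methods::*;
use iterative_methods::conjugate_gradient::ConjugateGradient;
use iterative_methods::utils::make_3x3_pd_system_2;

// Plays the same role as the demo's residual_l2; the demo's exact body is not shown here.
fn residual_l2(result: &ConjugateGradient) -> f64 {
    // r = A x - b, built from the solver-state fields shown in src/algorithms.rs below.
    let res = result.a.dot(&result.solution) - &result.b;
    res.dot(&res).sqrt()
}

fn main() {
    let p = make_3x3_pd_system_2();
    let cg_iter = ConjugateGradient::for_problem(&p).take(20);
    // Pair every intermediate solution with its residual norm, as in the demo above.
    let mut cg_iter = assess(cg_iter, residual_l2);
    while let Some(AnnotatedResult { annotation, result }) = cg_iter.next() {
        println!("x = {:+.3}, ||Ax - b||_2 = {:.5}", result.solution, annotation);
    }
}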
@@ -97,14 +98,16 @@ fn cg_demo_pt2_2() {
// Set up a problem for which we happen to know the solution
let p = make_3x3_pd_system_2();
let optimum = rcarr1(&[-4.0, 6., -4.]);
let cg_iter = ConjugateGradient::for_problem(&p);

let cg_iter = conjugate_gradient(&p);
// Cap the number of iterations.
let cg_iter = cg_iter.take(80);

// Time each iteration: record both the time spent only in the
// preceding steps (the method), excluding downstream evaluation and
// I/O (tracking overhead), and the elapsed clock time (combining both).
let cg_iter = time(cg_iter);

// Record multiple measures of quality
let cg_iter = assess(cg_iter, |TimedResult { result, .. }| {
(
@@ -113,11 +116,13 @@ fn cg_demo_pt2_2() {
a_distance(result, optimum.clone()),
)
});

// Stop if converged by both criteria
fn small_residual((euc, linf, _): &(f64, f64, f64)) -> bool {
euc < &1e-3 && linf < &1e-3
}
let mut cg_iter = take_until(cg_iter, |ar| small_residual(&ar.annotation));

// Output progress
while let Some(AnnotatedResult {
annotation: (euc, linf, a_dist),
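The stopping rule above requires both residual norms to fall below 1e-3; the third annotation entry (the a_distance to the known optimum) is reported but does not affect stopping. A tiny standalone check of the predicate, copied verbatim from the demo:

fn small_residual((euc, linf, _): &(f64, f64, f64)) -> bool {
    euc < &1e-3 && linf < &1e-3
}

fn main() {
    // Converged: both the euclidean and the infinity norm of the residual are below 1e-3.
    assert!(small_residual(&(5e-4, 5e-4, 0.2)));
    // Not converged: the infinity norm is still large, so iteration continues.
    assert!(!small_residual(&(5e-4, 0.5, 0.2)));
    println!("small_residual behaves as expected");
}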
8 changes: 4 additions & 4 deletions src/algorithms.rs
@@ -1,4 +1,4 @@
pub use crate::conjugate_gradient::conjugate_gradient;
pub use crate::conjugate_gradient::ConjugateGradient;
use ndarray::ArcArray1;
use ndarray::ArcArray2;
pub type S = f64;
@@ -9,7 +9,7 @@ pub type V = ArcArray1<S>;
#[cfg(test)]
mod tests {

use crate::conjugate_gradient::conjugate_gradient;
use crate::conjugate_gradient::ConjugateGradient;
use crate::inspect;
use crate::last;
use crate::utils::make_3x3_pd_system_1;
@@ -25,12 +25,12 @@ mod tests {
use streaming_iterator::StreamingIterator;

pub fn solve_approximately(p: LinearSystem) -> V {
let solution = conjugate_gradient(&p).take(20);
let solution = ConjugateGradient::for_problem(&p).take(20);
last(solution.map(|s| s.x_k.clone())).expect("CGIterable should always return a solution.")
}

pub fn show_progress(p: LinearSystem) {
let cg_iter = conjugate_gradient(&p).take(20);
let cg_iter = ConjugateGradient::for_problem(&p).take(20);
let mut cg_print_iter = inspect(cg_iter, |result| {
//println!("result: {:?}", result);
let res = result.a.dot(&result.solution) - &result.b;
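A hedged sketch of exercising the two helpers above from another test in this module, reusing the imports already at the top of the module; it assumes make_3x3_pd_system_1 builds the same fixed 3x3 system on every call, as the known-optimum systems in the examples suggest:

#[test]
fn helpers_run_on_small_system() {
    // Twenty CG iterations on the fixed 3x3 system, keeping only the last iterate.
    let x = solve_approximately(make_3x3_pd_system_1());
    // The approximate solution has the system's dimension.
    assert_eq!(x.len(), 3);
    // Trace per-iteration progress for the same system.
    show_progress(make_3x3_pd_system_1());
}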
58 changes: 30 additions & 28 deletions src/conjugate_gradient.rs
@@ -75,35 +75,37 @@ pub struct ConjugateGradient {
pub pap_km: S,
}

/// Initialize a conjugate gradient iterative solver
pub fn conjugate_gradient(p: &LinearSystem) -> ConjugateGradient {
let x_0 = match &p.x0 {
Some(x) => x.clone(),
None => ArrayBase::zeros(p.a.shape()[0]),
};
impl ConjugateGradient {
/// Initialize a conjugate gradient iterative solver to solve linear system `p`.
pub fn for_problem(p: &LinearSystem) -> ConjugateGradient {
let x_0 = match &p.x0 {
Some(x) => x.clone(),
None => ArrayBase::zeros(p.a.shape()[0]),
};

// Set r_0 = A*x_0 - b and p_0 =-r_0, k=0
let r_k = (&p.a.dot(&x_0) - &p.b).to_shared();
let r_k2 = r_k.dot(&r_k);
let r_km2 = NAN;
let p_k = -r_k.clone();
let ap_k = p.a.dot(&p_k).to_shared();
let pap_k = p_k.dot(&ap_k);
let pap_km = NAN;
ConjugateGradient {
x_k: x_0.clone(),
solution: x_0,
a: p.a.clone(),
b: p.b.clone(),
r_k,
r_k2,
r_km2,
p_k,
ap_k,
pap_k,
pap_km,
alpha_k: NAN,
beta_k: NAN,
// Set r_0 = A*x_0 - b and p_0 =-r_0, k=0
let r_k = (&p.a.dot(&x_0) - &p.b).to_shared();
let r_k2 = r_k.dot(&r_k);
let r_km2 = NAN;
let p_k = -r_k.clone();
let ap_k = p.a.dot(&p_k).to_shared();
let pap_k = p_k.dot(&ap_k);
let pap_km = NAN;
ConjugateGradient {
x_k: x_0.clone(),
solution: x_0,
a: p.a.clone(),
b: p.b.clone(),
r_k,
r_k2,
r_km2,
p_k,
ap_k,
pap_k,
pap_km,
alpha_k: NAN,
beta_k: NAN,
}
}
}

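The fields initialized here (r_k, r_k2, p_k, ap_k, pap_k, plus the lagged r_km2 and pap_km and the yet-unset alpha_k and beta_k) are what the textbook conjugate gradient recurrence consumes. For reference only, a standalone sketch of one such step in the same sign convention (r = A x - b, p_0 = -r_0); this is the standard recurrence, not a copy of this crate's iteration code:

use ndarray::{array, Array1, Array2};

// One textbook conjugate gradient step, with r = A x - b and p_0 = -r_0 as above.
fn cg_step(a: &Array2<f64>, x: &mut Array1<f64>, r: &mut Array1<f64>, p: &mut Array1<f64>) {
    let ap = a.dot(&*p); // A p_k, the crate's ap_k
    let r2 = r.dot(&*r); // ||r_k||^2, the crate's r_k2
    let pap = p.dot(&ap); // p_k . A p_k, the crate's pap_k
    let alpha = r2 / pap; // alpha_k: step length along the search direction
    let x_next = &*x + &(&*p * alpha); // x_{k+1} = x_k + alpha_k p_k
    let r_next = &*r + &(&ap * alpha); // r_{k+1} = r_k + alpha_k A p_k
    let beta = r_next.dot(&r_next) / r2; // beta_k = ||r_{k+1}||^2 / ||r_k||^2
    let p_next = -&r_next + &(&*p * beta); // p_{k+1} = -r_{k+1} + beta_k p_k
    *x = x_next;
    *r = r_next;
    *p = p_next;
}

fn main() {
    // A small symmetric positive-definite system; CG solves a 2x2 system in two steps.
    let a = array![[4.0, 1.0], [1.0, 3.0]];
    let b = array![1.0, 2.0];
    let mut x = Array1::<f64>::zeros(2);
    let mut r = a.dot(&x) - &b; // r_0 = A x_0 - b
    let mut p = -r.clone(); // p_0 = -r_0
    for _ in 0..2 {
        cg_step(&a, &mut x, &mut r, &mut p);
    }
    println!("x = {:.4}, ||A x - b||_2 = {:.2e}", x, r.dot(&r).sqrt());
}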
4 changes: 2 additions & 2 deletions tests/integration_tests.rs
@@ -1,7 +1,7 @@
use crate::utils::{
expose_w, generate_step_stream, make_3x3_pd_system_1, read_yaml_to_string, Counter,
};
use iterative_methods::algorithms::conjugate_gradient;
use iterative_methods::algorithms::ConjugateGradient;
use iterative_methods::*;
extern crate streaming_iterator;
use crate::streaming_iterator::*;
@@ -13,7 +13,7 @@ use rand_pcg::Pcg64;
#[test]
fn test_timed_iterable() {
let p = make_3x3_pd_system_1();
let cg_iter = conjugate_gradient(&p).take(50);
let cg_iter = ConjugateGradient::for_problem(&p).take(50);
let cg_timed_iter = time(cg_iter);
let mut start_times = Vec::new();
let mut durations = Vec::new();
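A hedged sketch of draining the timed stream built in this test. It relies only on names shown elsewhere in this diff (time, TimedResult and its result field, ConjugateGradient::for_problem, and the make_3x3_pd_system_2 helper); the timing fields collected into start_times and durations are elided in this excerpt and left untouched here:

use streaming_iterator::StreamingIterator;
use iterative_methods::*;
use iterative_methods::conjugate_gradient::ConjugateGradient;
use iterative_methods::utils::make_3x3_pd_system_2;

fn main() {
    let p = make_3x3_pd_system_2();
    let cg_iter = ConjugateGradient::for_problem(&p).take(50);
    let mut cg_timed_iter = time(cg_iter);
    // Only the wrapped solver state is inspected; the timing data stays inside TimedResult.
    while let Some(TimedResult { result, .. }) = cg_timed_iter.next() {
        let res = result.a.dot(&result.solution) - &result.b;
        println!("||Ax - b||_2 = {:.5}", res.dot(&res).sqrt());
    }
}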
