Skip to content

Commit

Permalink
General annotation adapter (#26)
Browse files Browse the repository at this point in the history
* Use generic annotation for cost
* Implement tee in terms of AnnotatedIterable

Unrelated changes:
* watch only iterative-methods doc files
* remove chaining from cg_demo
* ignore unstable test
* Add some doc comments
  • Loading branch information
daniel-vainsencher authored Mar 17, 2021
1 parent 2a5fe34 commit eef4f71
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 140 deletions.
44 changes: 21 additions & 23 deletions examples/conjugate_gradient_method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,42 +10,40 @@ use iterative_methods::*;
fn cg_demo() {
let p = make_3x3_psd_system_2();
println!("a: \n{}", &p.a);
let cg_iter = CGIterable::conjugate_gradient(p)
// Upper bound the number of iterations
.take(20)
// Apply a quality based stopping condition; this relies on
// algorithm internals, requiring all state to be exposed and
// not just the result.
.take_while(|cgi| cgi.rsprev.sqrt() > 1e-6);
// Because time, tee are not part of the StreamingIterator trait,
// they cannot be chained syntactically as in the above.
let cg_iter = CGIterable::conjugate_gradient(p);
// Upper bound the number of iterations
let cg_iter = cg_iter.take(20);
// Apply a quality based stopping condition; this relies on
// algorithm internals, requiring all state to be exposed and
// not just the result.
let cg_iter = cg_iter.take_while(|cgi| cgi.rsprev.sqrt() > 1e-6);

// TODO can this be fixed? see iterutils crate.

//Note the side effect of tee is applied exactly to every x
// Note the side effect of tee is applied exactly to every x
// produced above, the sequence of which is not affected at
// all. This is just like applying a side effect inside the while
// loop, except we can compose multiple tee, each with its own
// effect.
let step_by_cg_iter = step_by(cg_iter, 2);
let timed_cg_iter = time(step_by_cg_iter);
let cg_iter = step_by(cg_iter, 2);
let cg_iter = time(cg_iter);

// We are assessing after timing, which means that computing this
// function is excluded from the duration measurements, which can
// be important in other cases.
let ct_cg_iter = assess(timed_cg_iter, |TimedResult { result, .. }| {
let res = result.a.dot(&result.x) - &result.b;
res.dot(&res)
});
// function is excluded from the duration measurements, which is
// generally the right way to do it, though not important here.
fn score(TimedResult { result, .. }: &TimedResult<CGIterable>) -> f64 {
result.rs
}

let cg_iter = assess(cg_iter, score);
let mut cg_print_iter = tee(
ct_cg_iter,
|CostResult {
cg_iter,
|AnnotatedResult {
result:
TimedResult {
result,
start_time,
duration,
},
cost,
annotation: cost,
}| {
let res = result.a.dot(&result.x) - &result.b;
println!(
Expand Down
4 changes: 2 additions & 2 deletions feedback_loop.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# prereq: sudo npm install -g browser-sync
browser-sync start --ss target/doc -s target/doc --directory --no-open --no-inject-changes --watch &
cargo watch -x check -x fmt -x doc -x clippy -x build -x test
browser-sync start --ss target/doc/iterative_methods -s target/doc/iterative_methods --directory --no-open --no-inject-changes --watch &
cargo watch -x check -x fmt -x doc -x clippy -x build -x test
18 changes: 2 additions & 16 deletions src/algorithms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pub mod cg_method {

pub fn solve_approximately(p: LinearSystem) -> V {
let solution = CGIterable::conjugate_gradient(p).take(200);
last(solution.map(|s| s.x.clone()))
last(solution.map(|s| s.x.clone())).expect("CGIterable should always return a solution.")
}

pub fn show_progress(p: LinearSystem) {
Expand All @@ -101,7 +101,6 @@ pub mod cg_method {
mod tests {

use super::cg_method::*;
use crate::last;
use crate::utils::make_3x3_psd_system;
use crate::utils::make_3x3_psd_system_1;
use crate::utils::LinearSystem;
Expand Down Expand Up @@ -180,20 +179,6 @@ mod tests {
}
}

#[test]
fn test_last() {
let v = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
let iter = convert(v.clone());
assert!(last(iter) == 9);
}

#[test]
#[should_panic(expected = "StreamingIterator last expects at least one non-None element.")]
fn test_last_fail() {
let v: Vec<u32> = vec![];
last(convert(v.clone()));
}

#[test]
fn cg_simple_test() {
let p = make_3x3_psd_system_1();
Expand Down Expand Up @@ -232,6 +217,7 @@ mod tests {
assert!(!result.is_error());
}

#[ignore]
#[test]
fn cg_rank_one_v() {
// This test is currently discarded by test_arbitrary_3x3_pd
Expand Down
Loading

0 comments on commit eef4f71

Please sign in to comment.