diff --git a/examples/conjugate_gradient_method.rs b/examples/conjugate_gradient_method.rs index 1319503..093c4d4 100644 --- a/examples/conjugate_gradient_method.rs +++ b/examples/conjugate_gradient_method.rs @@ -10,42 +10,40 @@ use iterative_methods::*; fn cg_demo() { let p = make_3x3_psd_system_2(); println!("a: \n{}", &p.a); - let cg_iter = CGIterable::conjugate_gradient(p) - // Upper bound the number of iterations - .take(20) - // Apply a quality based stopping condition; this relies on - // algorithm internals, requiring all state to be exposed and - // not just the result. - .take_while(|cgi| cgi.rsprev.sqrt() > 1e-6); - // Because time, tee are not part of the StreamingIterator trait, - // they cannot be chained syntactically as in the above. + let cg_iter = CGIterable::conjugate_gradient(p); + // Upper bound the number of iterations + let cg_iter = cg_iter.take(20); + // Apply a quality based stopping condition; this relies on + // algorithm internals, requiring all state to be exposed and + // not just the result. + let cg_iter = cg_iter.take_while(|cgi| cgi.rsprev.sqrt() > 1e-6); - // TODO can this be fixed? see iterutils crate. - - //Note the side effect of tee is applied exactly to every x + // Note the side effect of tee is applied exactly to every x // produced above, the sequence of which is not affected at // all. This is just like applying a side effect inside the while // loop, except we can compose multiple tee, each with its own // effect. - let step_by_cg_iter = step_by(cg_iter, 2); - let timed_cg_iter = time(step_by_cg_iter); + let cg_iter = step_by(cg_iter, 2); + let cg_iter = time(cg_iter); + // We are assessing after timing, which means that computing this - // function is excluded from the duration measurements, which can - // be important in other cases. - let ct_cg_iter = assess(timed_cg_iter, |TimedResult { result, .. }| { - let res = result.a.dot(&result.x) - &result.b; - res.dot(&res) - }); + // function is excluded from the duration measurements, which is + // generally the right way to do it, though not important here. + fn score(TimedResult { result, .. }: &TimedResult) -> f64 { + result.rs + } + + let cg_iter = assess(cg_iter, score); let mut cg_print_iter = tee( - ct_cg_iter, - |CostResult { + cg_iter, + |AnnotatedResult { result: TimedResult { result, start_time, duration, }, - cost, + annotation: cost, }| { let res = result.a.dot(&result.x) - &result.b; println!( diff --git a/feedback_loop.sh b/feedback_loop.sh index 94dfeac..1095348 100755 --- a/feedback_loop.sh +++ b/feedback_loop.sh @@ -1,3 +1,3 @@ # prereq: sudo npm install -g browser-sync -browser-sync start --ss target/doc -s target/doc --directory --no-open --no-inject-changes --watch & -cargo watch -x check -x fmt -x doc -x clippy -x build -x test +browser-sync start --ss target/doc/iterative_methods -s target/doc/iterative_methods --directory --no-open --no-inject-changes --watch & +cargo watch -x check -x fmt -x doc -x clippy -x build -x test diff --git a/src/algorithms.rs b/src/algorithms.rs index 1ecd857..4690022 100644 --- a/src/algorithms.rs +++ b/src/algorithms.rs @@ -75,7 +75,7 @@ pub mod cg_method { pub fn solve_approximately(p: LinearSystem) -> V { let solution = CGIterable::conjugate_gradient(p).take(200); - last(solution.map(|s| s.x.clone())) + last(solution.map(|s| s.x.clone())).expect("CGIterable should always return a solution.") } pub fn show_progress(p: LinearSystem) { @@ -101,7 +101,6 @@ pub mod cg_method { mod tests { use super::cg_method::*; - use crate::last; use crate::utils::make_3x3_psd_system; use crate::utils::make_3x3_psd_system_1; use crate::utils::LinearSystem; @@ -180,20 +179,6 @@ mod tests { } } - #[test] - fn test_last() { - let v = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; - let iter = convert(v.clone()); - assert!(last(iter) == 9); - } - - #[test] - #[should_panic(expected = "StreamingIterator last expects at least one non-None element.")] - fn test_last_fail() { - let v: Vec = vec![]; - last(convert(v.clone())); - } - #[test] fn cg_simple_test() { let p = make_3x3_psd_system_1(); @@ -232,6 +217,7 @@ mod tests { assert!(!result.is_error()); } + #[ignore] #[test] fn cg_rank_one_v() { // This test is currently discarded by test_arbitrary_3x3_pd diff --git a/src/lib.rs b/src/lib.rs index 146d60d..9b51c26 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,48 +12,56 @@ use streaming_iterator::*; pub mod algorithms; pub mod utils; -/// Annotate the underlying items with a cost (non-negative f64) as -/// given by a function. -pub struct CostIterable -where - I: StreamingIterator, -{ - it: I, - f: F, - last: Option>, -} - -/// Store the cost of a state. Lower costs are better. +/// Store a generic annotation next to the state. #[derive(Clone)] -pub struct CostResult { +pub struct AnnotatedResult { pub result: T, - pub cost: f64, + pub annotation: A, +} + +/// An adaptor that annotates every underlying item `x` with `f(x)`. +pub struct AnnotatedIterable +where + I: Sized + StreamingIterator, + T: Clone, + F: FnMut(&T) -> A, +{ + pub it: I, + pub f: F, + pub current: Option>, } -pub fn assess(it: I, f: F) -> CostIterable +impl AnnotatedIterable where I: StreamingIterator, - F: FnMut(&I::Item) -> f64, + T: Sized + Clone, + F: FnMut(&T) -> A, { - CostIterable { it, f, last: None } + /// Annotate every underlying item with the result of applying `f` to it. + fn new(it: I, f: F) -> AnnotatedIterable { + AnnotatedIterable { + it, + f: f, + current: None, + } + } } -impl StreamingIterator for CostIterable +impl StreamingIterator for AnnotatedIterable where I: StreamingIterator, T: Sized + Clone, - F: FnMut(&T) -> f64, + F: FnMut(&T) -> A, { - type Item = CostResult; + type Item = AnnotatedResult; fn advance(&mut self) { - let _before = Instant::now(); self.it.advance(); - self.last = match self.it.get() { + self.current = match self.it.get() { Some(n) => { - let cost = (self.f)(n); - Some(CostResult { - cost, + let annotation = (self.f)(n); + Some(AnnotatedResult { + annotation, result: n.clone(), }) } @@ -62,64 +70,40 @@ where } fn get(&self) -> Option<&Self::Item> { - match &self.last { + match &self.current { Some(tr) => Some(&tr), None => None, } } } -/// Pass the values from the streaming iterator through, running a -/// function on each for side effects. -pub struct Tee { - pub it: I, - pub f: F, +/// Annotate every underlying item with its score, as defined by `f`. +pub fn assess(it: I, f: F) -> AnnotatedIterable +where + T: Clone, + F: FnMut(&T) -> f64, + I: StreamingIterator, +{ + AnnotatedIterable::new(it, f) } -/* -// TODO: For ideal convenience, this should be implemented inside the StreamingIterator trait. -impl StreamingIterator { - fn tee(self, f: F) -> Tee - where - Self: Sized, - F: Fn(&Self::Item) - { - Tee { - it: self, - f: f - } - } -} */ - -pub fn tee(it: I, f: F) -> Tee +/// Apply `f` to every underlying item. +pub fn tee(it: I, f: F) -> AnnotatedIterable where I: Sized + StreamingIterator, F: FnMut(&T), + T: Clone, { - Tee { it: it, f: f } + AnnotatedIterable::new(it, f) } -impl StreamingIterator for Tee +/// Get the item before the first None, assuming any exist. +pub fn last(it: I) -> Option where - I: StreamingIterator, - F: FnMut(&I::Item), + I: StreamingIterator, + T: Sized + Clone, { - type Item = I::Item; - - #[inline] - fn advance(&mut self) { - // The side effect happens exactly once for each new value - // generated. - self.it.advance(); - if let Some(x) = self.it.get() { - (self.f)(x); - } - } - - #[inline] - fn get(&self) -> Option<&I::Item> { - self.it.get() - } + it.fold(None, |_acc, i| Some((*i).clone())) } /// Times every call to `advance` on the underlying @@ -128,6 +112,7 @@ where pub struct TimedIterable where I: StreamingIterator, + T: Clone, { it: I, current: Option>, @@ -145,36 +130,24 @@ pub struct TimedResult { pub duration: Duration, } -pub fn last(it: I) -> T -where - I: StreamingIterator, - T: Sized + Clone, -{ - let last_some = it.fold(None, |_acc, i| Some((*i).clone())); - let last_item = last_some - .expect("StreamingIterator last expects at least one non-None element.") - .clone(); - last_item -} - /// Wrap each value of a streaming iterator with the durations: /// - between the call to this function and start of the value's computation /// - it took to calculate that value pub fn time(it: I) -> TimedIterable where I: Sized + StreamingIterator, - T: Sized, + T: Sized + Clone, { TimedIterable { it: it, - timer: Instant::now(), current: None, + timer: Instant::now(), } } impl StreamingIterator for TimedIterable where - I: StreamingIterator, + I: Sized + StreamingIterator, T: Sized + Clone, { type Item = TimedResult; @@ -379,15 +352,9 @@ where pub f: F, } -// NOTE: -// Either -// F: FnMut(&I::Item) -> f64, -// F: FnMut(&T) -> f64 -// compiles pub fn wd_iterable(it: I, f: F) -> WDIterable where I: StreamingIterator, - // F: FnMut(&I::Item) -> f64, F: FnMut(&T) -> f64, { WDIterable { @@ -397,16 +364,10 @@ where } } -// NOTE: -// Either -// F: FnMut(&I::Item) -> f64, -// F: FnMut(&T) -> f64 -// compiles impl StreamingIterator for WDIterable where I: StreamingIterator, F: FnMut(&T) -> f64, - // F: FnMut(&I::Item) -> f64, T: Sized + Clone, { type Item = WeightedDatum; @@ -491,7 +452,7 @@ pub struct WeightedReservoirIterable { rng: Pcg64, } -// Create a WeightedReservoirIterable +/// Create a random sample of the underlying weighted stream. pub fn weighted_reservoir_iterable( it: I, capacity: usize, @@ -592,14 +553,13 @@ mod tests { fn test_last() { let v = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; let iter = convert(v.clone()); - assert!(last(iter) == 9); + assert!(last(iter) == Some(9)); } #[test] - #[should_panic(expected = "StreamingIterator last expects at least one non-None element.")] - fn test_last_fail() { + fn test_last_none() { let v: Vec = vec![]; - last(convert(v.clone())); + assert!(last(convert(v.clone())) == None); } #[test] @@ -614,6 +574,23 @@ mod tests { } } + #[test] + fn annotate_test() { + let v = vec![0., 1., 2.]; + let iter = convert(v); + fn f(num: &f64) -> f64 { + num * 2. + } + let target_annotations = vec![0., 2., 4.]; + let mut annotations: Vec = Vec::with_capacity(3); + let mut ann_iter = AnnotatedIterable::new(iter, f); + while let Some(n) = ann_iter.next() { + annotations.push(n.annotation); + } + assert_eq!(annotations, target_annotations); + } + + /// Tests for the ReservoirIterable adaptor /// This test asserts that the reservoir is filled with the correct items. #[test] fn fill_reservoir_test() {