11use  crate :: collections:: { HashMap ,  HashSet ,  VecDeque } ; 
2- use  crate :: tx_graph:: { TxAncestors ,  TxDescendants } ; 
2+ use  crate :: tx_graph:: { TxAncestors ,  TxDescendants ,   TxNode } ; 
33use  crate :: { Anchor ,  ChainOracle ,  TxGraph } ; 
44use  alloc:: boxed:: Box ; 
55use  alloc:: collections:: BTreeSet ; 
@@ -36,6 +36,9 @@ pub struct CanonicalIter<'g, A, C> {
3636    canonical :  CanonicalMap < A > , 
3737    not_canonical :  NotCanonicalSet , 
3838
39+     canonical_ancestors :  HashMap < Txid ,  Vec < Txid > > , 
40+     canonical_roots :  VecDeque < Txid > , 
41+ 
3942    queue :  VecDeque < Txid > , 
4043} 
4144
@@ -75,6 +78,8 @@ impl<'g, A: Anchor, C: ChainOracle> CanonicalIter<'g, A, C> {
7578            unprocessed_leftover_txs :  VecDeque :: new ( ) , 
7679            canonical :  HashMap :: new ( ) , 
7780            not_canonical :  HashSet :: new ( ) , 
81+             canonical_ancestors :  HashMap :: new ( ) , 
82+             canonical_roots :  VecDeque :: new ( ) , 
7883            queue :  VecDeque :: new ( ) , 
7984        } 
8085    } 
@@ -160,7 +165,7 @@ impl<'g, A: Anchor, C: ChainOracle> CanonicalIter<'g, A, C> {
160165
161166                // Any conflicts with a canonical tx can be added to `not_canonical`. Descendants 
162167                // of `not_canonical` txs can also be added to `not_canonical`. 
163-                 for  ( _,  conflict_txid)  in  self . tx_graph . direct_conflicts ( & tx)  { 
168+                 for  ( _,  conflict_txid)  in  self . tx_graph . direct_conflicts ( & tx. clone ( ) )  { 
164169                    TxDescendants :: new_include_root ( 
165170                        self . tx_graph , 
166171                        conflict_txid, 
@@ -181,6 +186,18 @@ impl<'g, A: Anchor, C: ChainOracle> CanonicalIter<'g, A, C> {
181186                    detected_self_double_spend = true ; 
182187                    return  None ; 
183188                } 
189+ 
190+                 // Calculates all the existing ancestors for the given Txid 
191+                 self . canonical_ancestors . insert ( 
192+                     this_txid, 
193+                     tx. clone ( ) 
194+                         . input 
195+                         . iter ( ) 
196+                         . filter ( |txin| self . tx_graph . get_tx ( txin. previous_output . txid ) . is_some ( ) ) 
197+                         . map ( |txin| txin. previous_output . txid ) 
198+                         . collect ( ) , 
199+                 ) ; 
200+ 
184201                canonical_entry. insert ( ( tx,  this_reason) ) ; 
185202                Some ( this_txid) 
186203            } , 
@@ -190,12 +207,29 @@ impl<'g, A: Anchor, C: ChainOracle> CanonicalIter<'g, A, C> {
190207        if  detected_self_double_spend { 
191208            for  txid in  staged_queue { 
192209                self . canonical . remove ( & txid) ; 
210+                 self . canonical_ancestors . remove ( & txid) ; 
193211            } 
194212            for  txid in  undo_not_canonical { 
195213                self . not_canonical . remove ( & txid) ; 
196214            } 
197215        }  else  { 
198-             self . queue . extend ( staged_queue) ; 
216+             // TODO: (@oleonardolima) Can this be optimized somehow ? 
217+             // Can we just do a simple lookup on the `canonical_ancestors` field ? 
218+             for  txid in  staged_queue { 
219+                 let  tx = self . tx_graph . get_tx ( txid) . expect ( "tx must exist" ) ; 
220+                 let  ancestors = tx
221+                     . input 
222+                     . iter ( ) 
223+                     . map ( |txin| txin. previous_output . txid ) 
224+                     . filter_map ( |prev_txid| self . tx_graph . get_tx ( prev_txid) ) 
225+                     . collect :: < Vec < _ > > ( ) ; 
226+ 
227+                 // check if it's a root: it's either a coinbase transaction or has not known 
228+                 // ancestors in the tx_graph 
229+                 if  tx. is_coinbase ( )  || ancestors. is_empty ( )  { 
230+                     self . canonical_roots . push_back ( txid) ; 
231+                 } 
232+             } 
199233        } 
200234    } 
201235} 
@@ -204,52 +238,58 @@ impl<A: Anchor, C: ChainOracle> Iterator for CanonicalIter<'_, A, C> {
204238    type  Item  = Result < ( Txid ,  Arc < Transaction > ,  CanonicalReason < A > ) ,  C :: Error > ; 
205239
206240    fn  next ( & mut  self )  -> Option < Self :: Item >  { 
207-         loop  { 
208-             if  let  Some ( txid)  = self . queue . pop_front ( )  { 
209-                 let  ( tx,  reason)  = self 
210-                     . canonical 
211-                     . get ( & txid) 
212-                     . cloned ( ) 
213-                     . expect ( "reason must exist" ) ; 
214-                 return  Some ( Ok ( ( txid,  tx,  reason) ) ) ; 
241+         while  let  Some ( ( txid,  tx) )  = self . unprocessed_assumed_txs . next ( )  { 
242+             if  !self . is_canonicalized ( txid)  { 
243+                 self . mark_canonical ( txid,  tx,  CanonicalReason :: assumed ( ) ) ; 
215244            } 
245+         } 
216246
217-             if  let  Some ( ( txid,  tx) )  = self . unprocessed_assumed_txs . next ( )  { 
218-                 if  !self . is_canonicalized ( txid)  { 
219-                     self . mark_canonical ( txid,  tx,  CanonicalReason :: assumed ( ) ) ; 
247+         while  let  Some ( ( txid,  tx,  anchors) )  = self . unprocessed_anchored_txs . next ( )  { 
248+             if  !self . is_canonicalized ( txid)  { 
249+                 if  let  Err ( err)  = self . scan_anchors ( txid,  tx,  anchors)  { 
250+                     return  Some ( Err ( err) ) ; 
220251                } 
221252            } 
253+         } 
222254
223-             if  let  Some ( ( txid,  tx,  anchors) )  = self . unprocessed_anchored_txs . next ( )  { 
224-                 if  !self . is_canonicalized ( txid)  { 
225-                     if  let  Err ( err)  = self . scan_anchors ( txid,  tx,  anchors)  { 
226-                         return  Some ( Err ( err) ) ; 
227-                     } 
228-                 } 
229-                 continue ; 
255+         while  let  Some ( ( txid,  tx,  last_seen) )  = self . unprocessed_seen_txs . next ( )  { 
256+             debug_assert ! ( 
257+                 !tx. is_coinbase( ) , 
258+                 "Coinbase txs must not have `last_seen` (in mempool) value" 
259+             ) ; 
260+             if  !self . is_canonicalized ( txid)  { 
261+                 let  observed_in = ObservedIn :: Mempool ( last_seen) ; 
262+                 self . mark_canonical ( txid,  tx,  CanonicalReason :: from_observed_in ( observed_in) ) ; 
230263            } 
264+         } 
231265
232-             if  let  Some ( ( txid,  tx,  last_seen) )  = self . unprocessed_seen_txs . next ( )  { 
233-                 debug_assert ! ( 
234-                     !tx. is_coinbase( ) , 
235-                     "Coinbase txs must not have `last_seen` (in mempool) value" 
236-                 ) ; 
237-                 if  !self . is_canonicalized ( txid)  { 
238-                     let  observed_in = ObservedIn :: Mempool ( last_seen) ; 
239-                     self . mark_canonical ( txid,  tx,  CanonicalReason :: from_observed_in ( observed_in) ) ; 
240-                 } 
241-                 continue ; 
266+         while  let  Some ( ( txid,  tx,  height) )  = self . unprocessed_leftover_txs . pop_front ( )  { 
267+             if  !self . is_canonicalized ( txid)  && !tx. is_coinbase ( )  { 
268+                 let  observed_in = ObservedIn :: Block ( height) ; 
269+                 self . mark_canonical ( txid,  tx,  CanonicalReason :: from_observed_in ( observed_in) ) ; 
242270            } 
271+         } 
243272
244-             if  let  Some ( ( txid,  tx,  height) )  = self . unprocessed_leftover_txs . pop_front ( )  { 
245-                 if  !self . is_canonicalized ( txid)  && !tx. is_coinbase ( )  { 
246-                     let  observed_in = ObservedIn :: Block ( height) ; 
247-                     self . mark_canonical ( txid,  tx,  CanonicalReason :: from_observed_in ( observed_in) ) ; 
248-                 } 
249-                 continue ; 
250-             } 
273+         if  !self . canonical_roots . is_empty ( )  { 
274+             let  topological_iter = TopologicalIteratorWithLevels :: new ( 
275+                 self . tx_graph , 
276+                 self . chain , 
277+                 self . chain_tip , 
278+                 & self . canonical_ancestors , 
279+                 self . canonical_roots . drain ( ..) . collect ( ) , 
280+             ) ; 
281+             self . queue . extend ( topological_iter) ; 
282+         } 
251283
252-             return  None ; 
284+         if  let  Some ( txid)  = self . queue . pop_front ( )  { 
285+             let  ( tx,  reason)  = self 
286+                 . canonical 
287+                 . get ( & txid) 
288+                 . cloned ( ) 
289+                 . expect ( "canonical reason must exist" ) ; 
290+             Some ( Ok ( ( txid,  tx,  reason) ) ) 
291+         }  else  { 
292+             None 
253293        } 
254294    } 
255295} 
@@ -342,3 +382,129 @@ impl<A: Clone> CanonicalReason<A> {
342382        } 
343383    } 
344384} 
385+ 
386+ struct  TopologicalIteratorWithLevels < ' a ,  A ,  C >  { 
387+     tx_graph :  & ' a  TxGraph < A > , 
388+     chain :  & ' a  C , 
389+     chain_tip :  BlockId , 
390+ 
391+     current_level :  Vec < Txid > , 
392+     next_level :  Vec < Txid > , 
393+ 
394+     adj_list :  HashMap < Txid ,  Vec < Txid > > , 
395+     parent_count :  HashMap < Txid ,  usize > , 
396+ 
397+     current_index :  usize , 
398+ } 
399+ 
400+ impl < ' a ,  A :  Anchor ,  C :  ChainOracle >  TopologicalIteratorWithLevels < ' a ,  A ,  C >  { 
401+     fn  new ( 
402+         tx_graph :  & ' a  TxGraph < A > , 
403+         chain :  & ' a  C , 
404+         chain_tip :  BlockId , 
405+         ancestors_by_txid :  & HashMap < Txid ,  Vec < Txid > > , 
406+         roots :  Vec < Txid > , 
407+     )  -> Self  { 
408+         let  mut  parent_count = HashMap :: new ( ) ; 
409+         let  mut  adj_list:  HashMap < Txid ,  Vec < Txid > >  = HashMap :: new ( ) ; 
410+ 
411+         for  ( txid,  ancestors)  in  ancestors_by_txid { 
412+             for  ancestor in  ancestors { 
413+                 adj_list. entry ( * ancestor) . or_default ( ) . push ( * txid) ; 
414+                 * parent_count. entry ( * txid) . or_insert ( 0 )  += 1 ; 
415+             } 
416+         } 
417+ 
418+         let  mut  current_level:  Vec < Txid >  = roots. to_vec ( ) ; 
419+ 
420+         // Sort the initial level by confirmation height 
421+         current_level. sort_by_key ( |& txid| { 
422+             let  tx_node = tx_graph. get_tx_node ( txid) . expect ( "tx should exist" ) ; 
423+             Self :: find_direct_anchor ( & tx_node,  chain,  chain_tip) 
424+                 . expect ( "should not fail" ) 
425+                 . map ( |anchor| anchor. confirmation_height_upper_bound ( ) ) 
426+                 . unwrap_or ( u32:: MAX ) 
427+         } ) ; 
428+ 
429+         Self  { 
430+             current_level, 
431+             next_level :  Vec :: new ( ) , 
432+             adj_list, 
433+             parent_count, 
434+             current_index :  0 , 
435+             tx_graph, 
436+             chain, 
437+             chain_tip, 
438+         } 
439+     } 
440+ 
441+     fn  find_direct_anchor ( 
442+         tx_node :  & TxNode < ' _ ,  Arc < Transaction > ,  A > , 
443+         chain :  & C , 
444+         chain_tip :  BlockId , 
445+     )  -> Result < Option < A > ,  C :: Error >  { 
446+         tx_node
447+             . anchors 
448+             . iter ( ) 
449+             . find_map ( |a| -> Option < Result < A ,  C :: Error > >  { 
450+                 match  chain. is_block_in_chain ( a. anchor_block ( ) ,  chain_tip)  { 
451+                     Ok ( Some ( true ) )  => Some ( Ok ( a. clone ( ) ) ) , 
452+                     Ok ( Some ( false ) )  | Ok ( None )  => None , 
453+                     Err ( err)  => Some ( Err ( err) ) , 
454+                 } 
455+             } ) 
456+             . transpose ( ) 
457+     } 
458+ 
459+     fn  advance_to_next_level ( & mut  self )  { 
460+         self . current_level  = core:: mem:: take ( & mut  self . next_level ) ; 
461+ 
462+         // Sort by confirmation height 
463+         self . current_level . sort_by_key ( |& txid| { 
464+             let  tx_node = self . tx_graph . get_tx_node ( txid) . expect ( "tx should exist" ) ; 
465+ 
466+             Self :: find_direct_anchor ( & tx_node,  self . chain ,  self . chain_tip ) 
467+                 . expect ( "should not fail" ) 
468+                 . map ( |anchor| anchor. confirmation_height_upper_bound ( ) ) 
469+                 . unwrap_or ( u32:: MAX ) 
470+         } ) ; 
471+ 
472+         self . current_index  = 0 ; 
473+     } 
474+ } 
475+ 
476+ impl < ' a ,  A :  Anchor ,  C :  ChainOracle >  Iterator  for  TopologicalIteratorWithLevels < ' a ,  A ,  C >  { 
477+     type  Item  = Txid ; 
478+ 
479+     fn  next ( & mut  self )  -> Option < Self :: Item >  { 
480+         // If we've exhausted the current level, move to next 
481+         if  self . current_index  >= self . current_level . len ( )  { 
482+             if  self . next_level . is_empty ( )  { 
483+                 return  None ; 
484+             } 
485+             self . advance_to_next_level ( ) ; 
486+         } 
487+ 
488+         let  current = self . current_level [ self . current_index ] ; 
489+         self . current_index  += 1 ; 
490+ 
491+         // If this is the last item in current level, prepare dependents for next level 
492+         if  self . current_index  == self . current_level . len ( )  { 
493+             // Process all dependents of all transactions in current level 
494+             for  & tx in  & self . current_level  { 
495+                 if  let  Some ( dependents)  = self . adj_list . get ( & tx)  { 
496+                     for  & dependent in  dependents { 
497+                         if  let  Some ( degree)  = self . parent_count . get_mut ( & dependent)  { 
498+                             * degree -= 1 ; 
499+                             if  * degree == 0  { 
500+                                 self . next_level . push ( dependent) ; 
501+                             } 
502+                         } 
503+                     } 
504+                 } 
505+             } 
506+         } 
507+ 
508+         Some ( current) 
509+     } 
510+ } 
0 commit comments