diff --git a/Cargo.toml b/Cargo.toml index f79a5ede24..e0a5e0c517 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "bdk" -version = "0.30.2" +version = "0.30.4" authors = ["Alekos Filini ", "Riccardo Casatta "] homepage = "https://bitcoindevkit.org" repository = "https://github.com/bitcoindevkit/bdk" diff --git a/src/blockchain/any.rs b/src/blockchain/any.rs index 38b5f117f9..fbc3a0ed06 100644 --- a/src/blockchain/any.rs +++ b/src/blockchain/any.rs @@ -129,6 +129,21 @@ impl GetBlockHash for AnyBlockchain { #[maybe_async] impl WalletSync for AnyBlockchain { + fn wallet_setup_with_control( + &self, + database: &RefCell, + progress_update: Box, + control: SyncControl, + ) -> Result<(), Error> { + maybe_await!(impl_inner_method!( + self, + wallet_setup_with_control, + database, + progress_update, + control + )) + } + fn wallet_sync( &self, database: &RefCell, @@ -154,6 +169,21 @@ impl WalletSync for AnyBlockchain { progress_update )) } + + fn wallet_sync_with_control( + &self, + database: &RefCell, + progress_update: Box, + control: SyncControl, + ) -> Result<(), Error> { + maybe_await!(impl_inner_method!( + self, + wallet_sync_with_control, + database, + progress_update, + control + )) + } } impl_from!(boxed electrum::ElectrumBlockchain, AnyBlockchain, Electrum, #[cfg(feature = "electrum")]); diff --git a/src/blockchain/electrum.rs b/src/blockchain/electrum.rs index fda3258570..89a8f54f80 100644 --- a/src/blockchain/electrum.rs +++ b/src/blockchain/electrum.rs @@ -114,12 +114,14 @@ impl GetBlockHash for ElectrumBlockchain { } } -impl WalletSync for ElectrumBlockchain { - fn wallet_setup( +impl ElectrumBlockchain { + fn wallet_setup_impl( &self, database: &RefCell, - _progress_update: Box, + control: &SyncControl, ) -> Result<(), Error> { + control.check_cancelled()?; + let mut database = database.borrow_mut(); let database = database.deref_mut(); let mut request = script_sync::start(database, self.stop_gap)?; @@ -138,6 +140,8 @@ impl WalletSync for ElectrumBlockchain { let batch_update = loop { request = match request { Request::Script(script_req) => { + control.check_cancelled()?; + let scripts = script_req.request().take(chunk_size); let txids_per_script: Vec> = self .client @@ -164,6 +168,8 @@ impl WalletSync for ElectrumBlockchain { } Request::Conftime(conftime_req) => { + control.check_cancelled()?; + // collect up to chunk_size heights to fetch from electrum let needs_block_height = conftime_req .request() @@ -202,8 +208,10 @@ impl WalletSync for ElectrumBlockchain { conftime_req.satisfy(conftimes)? } Request::Tx(tx_req) => { + control.check_cancelled()?; + let needs_full = tx_req.request().take(chunk_size); - tx_cache.save_txs(needs_full.clone())?; + tx_cache.save_txs(needs_full.clone(), control)?; let full_transactions = needs_full .map(|txid| tx_cache.get(*txid).ok_or_else(electrum_goof)) .collect::, _>>()?; @@ -213,7 +221,7 @@ impl WalletSync for ElectrumBlockchain { .filter(|input| !input.previous_output.is_null()) .map(|input| &input.previous_output.txid) }); - tx_cache.save_txs(input_txs)?; + tx_cache.save_txs(input_txs, control)?; let full_details = full_transactions .into_iter() @@ -247,11 +255,40 @@ impl WalletSync for ElectrumBlockchain { } }; + control.check_cancelled()?; database.commit_batch(batch_update)?; Ok(()) } } +impl WalletSync for ElectrumBlockchain { + fn wallet_setup( + &self, + database: &RefCell, + _progress_update: Box, + ) -> Result<(), Error> { + self.wallet_setup_impl(database, &SyncControl::default()) + } + + fn wallet_setup_with_control( + &self, + database: &RefCell, + _progress_update: Box, + control: SyncControl, + ) -> Result<(), Error> { + self.wallet_setup_impl(database, &control) + } + + fn wallet_sync_with_control( + &self, + database: &RefCell, + _progress_update: Box, + control: SyncControl, + ) -> Result<(), Error> { + self.wallet_setup_impl(database, &control) + } +} + struct TxCache<'a, 'b, D> { db: &'a D, client: &'b Client, @@ -266,9 +303,14 @@ impl<'a, 'b, D: Database> TxCache<'a, 'b, D> { cache: HashMap::default(), } } - fn save_txs<'c>(&mut self, txids: impl Iterator) -> Result<(), Error> { + fn save_txs<'c>( + &mut self, + txids: impl Iterator, + control: &SyncControl, + ) -> Result<(), Error> { let mut need_fetch = vec![]; for txid in txids { + control.check_cancelled()?; if self.cache.contains_key(txid) { continue; } else if let Some(transaction) = self.db.get_raw_tx(txid)? { @@ -282,6 +324,7 @@ impl<'a, 'b, D: Database> TxCache<'a, 'b, D> { // of transactions at once, which creates enormous memory pressure. By chunking the batch // into more reasonably sized sub-queries, we allow time for memory to be freed. for chunk in need_fetch.chunks(1000) { + control.check_cancelled()?; let txs = self .client .batch_transaction_get(chunk) diff --git a/src/blockchain/mod.rs b/src/blockchain/mod.rs index 13c0b2a428..73163b7211 100644 --- a/src/blockchain/mod.rs +++ b/src/blockchain/mod.rs @@ -19,8 +19,10 @@ use std::cell::RefCell; use std::collections::HashSet; use std::ops::Deref; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::mpsc::{channel, Receiver, Sender}; use std::sync::Arc; +use std::time::Instant; use bitcoin::{BlockHash, Transaction, Txid}; @@ -119,6 +121,35 @@ pub trait GetBlockHash { } /// Trait for blockchains that can sync by updating the database directly. +#[derive(Debug, Clone, Default)] +pub struct SyncControl { + /// Shared cancellation flag observed by sync checkpoints. + pub cancel: Option>, + /// Deadline after which sync should abort at the next checkpoint. + pub deadline: Option, +} + +impl SyncControl { + /// Return true if cancellation should occur. + pub fn is_cancelled(&self) -> bool { + self.cancel + .as_ref() + .map_or(false, |cancel| cancel.load(Ordering::Relaxed)) + || self + .deadline + .map_or(false, |deadline| Instant::now() >= deadline) + } + + /// Return an error when cancellation is requested. + pub fn check_cancelled(&self) -> Result<(), Error> { + if self.is_cancelled() { + return Err(Error::SyncCancelled); + } + + Ok(()) + } +} + #[maybe_async] pub trait WalletSync { /// Setup the backend and populate the internal database for the first time @@ -138,6 +169,18 @@ pub trait WalletSync { progress_update: Box, ) -> Result<(), Error>; + /// Like [`Self::wallet_setup`], but with cooperative cancellation support. + /// + /// Default implementation preserves current behavior and ignores `control`. + fn wallet_setup_with_control( + &self, + database: &RefCell, + progress_update: Box, + _control: SyncControl, + ) -> Result<(), Error> { + maybe_await!(self.wallet_setup(database, progress_update)) + } + /// If not overridden, it defaults to calling [`Self::wallet_setup`] internally. /// /// This method should implement the logic required to iterate over the list of the wallet's @@ -162,6 +205,18 @@ pub trait WalletSync { ) -> Result<(), Error> { maybe_await!(self.wallet_setup(database, progress_update)) } + + /// Like [`Self::wallet_sync`], but with cooperative cancellation support. + /// + /// Default implementation preserves current behavior and ignores `control`. + fn wallet_sync_with_control( + &self, + database: &RefCell, + progress_update: Box, + _control: SyncControl, + ) -> Result<(), Error> { + maybe_await!(self.wallet_sync(database, progress_update)) + } } /// Trait for [`Blockchain`] types that can be created given a configuration @@ -381,6 +436,17 @@ impl WalletSync for Arc { maybe_await!(self.deref().wallet_setup(database, progress_update)) } + fn wallet_setup_with_control( + &self, + database: &RefCell, + progress_update: Box, + control: SyncControl, + ) -> Result<(), Error> { + maybe_await!(self + .deref() + .wallet_setup_with_control(database, progress_update, control)) + } + fn wallet_sync( &self, database: &RefCell, @@ -388,4 +454,15 @@ impl WalletSync for Arc { ) -> Result<(), Error> { maybe_await!(self.deref().wallet_sync(database, progress_update)) } + + fn wallet_sync_with_control( + &self, + database: &RefCell, + progress_update: Box, + control: SyncControl, + ) -> Result<(), Error> { + maybe_await!(self + .deref() + .wallet_sync_with_control(database, progress_update, control)) + } } diff --git a/src/error.rs b/src/error.rs index 29af11ebfa..03639e91fc 100644 --- a/src/error.rs +++ b/src/error.rs @@ -130,6 +130,8 @@ pub enum Error { /// [`crate::blockchain::WalletSync`] sync attempt failed due to missing scripts in cache which /// are needed to satisfy `stop_gap`. MissingCachedScripts(MissingCachedScripts), + /// [`crate::blockchain::WalletSync`] sync was cancelled by cooperative control. + SyncCancelled, #[cfg(feature = "electrum")] /// Electrum client error @@ -258,6 +260,7 @@ impl fmt::Display for Error { Self::MissingCachedScripts(missing_cached_scripts) => { write!(f, "Missing cached scripts: {:?}", missing_cached_scripts) } + Self::SyncCancelled => write!(f, "Wallet sync cancelled"), #[cfg(feature = "electrum")] Self::Electrum(err) => write!(f, "Electrum client error: {}", err), #[cfg(feature = "esplora")] diff --git a/src/wallet/mod.rs b/src/wallet/mod.rs index 4e10f015cf..9201f3935e 100644 --- a/src/wallet/mod.rs +++ b/src/wallet/mod.rs @@ -57,7 +57,7 @@ use signer::{SignOptions, SignerOrdering, SignersContainer, TransactionSigner}; use tx_builder::{BumpFee, CreateTx, FeePolicy, TxBuilder, TxParams}; use utils::{check_nsequence_rbf, After, Older, SecpCtx}; -use crate::blockchain::{GetHeight, NoopProgress, Progress, WalletSync}; +use crate::blockchain::{GetHeight, NoopProgress, Progress, SyncControl, WalletSync}; use crate::database::memory::MemoryDatabase; use crate::database::{AnyDatabase, BatchDatabase, BatchOperations, DatabaseUtils, SyncTime}; use crate::descriptor::checksum::calc_checksum_bytes_internal; @@ -168,6 +168,8 @@ impl fmt::Display for AddressInfo { pub struct SyncOptions { /// The progress tracker which may be informed when progress is made. pub progress: Option>, + /// Optional cooperative control for cancellation and deadlines. + pub sync_control: Option, } impl Wallet @@ -1752,11 +1754,17 @@ where blockchain: &B, sync_opts: SyncOptions, ) -> Result<(), Error> { + let SyncOptions { + progress, + sync_control, + } = sync_opts; + let control = sync_control.unwrap_or_default(); + debug!("Begin sync..."); - // TODO: for the next runs, we cannot reuse the `sync_opts.progress` object due to trait + // TODO: for the next runs, we cannot reuse the `progress` object due to trait // restrictions - let mut progress_iter = sync_opts.progress.into_iter(); + let mut progress_iter = progress.into_iter(); let mut new_progress = || { progress_iter .next() @@ -1780,9 +1788,17 @@ where for _ in 0..max_rounds { let sync_res = if run_setup { - maybe_await!(blockchain.wallet_setup(&self.database, new_progress())) + maybe_await!(blockchain.wallet_setup_with_control( + &self.database, + new_progress(), + control.clone(), + )) } else { - maybe_await!(blockchain.wallet_sync(&self.database, new_progress())) + maybe_await!(blockchain.wallet_sync_with_control( + &self.database, + new_progress(), + control.clone(), + )) }; // If the error is the special `MissingCachedScripts` error, we return the number of @@ -1821,6 +1837,18 @@ where Ok(()) } + /// Sync the internal database with the blockchain using cooperative cancellation control. + #[maybe_async] + pub fn sync_with_control( + &self, + blockchain: &B, + mut sync_opts: SyncOptions, + control: SyncControl, + ) -> Result<(), Error> { + sync_opts.sync_control = Some(control); + maybe_await!(self.sync(blockchain, sync_opts)) + } + /// Return the checksum of the public descriptor associated to `keychain` /// /// Internally calls [`Self::get_descriptor_for_keychain`] to fetch the right descriptor @@ -1911,8 +1939,12 @@ pub fn get_funded_wallet( pub(crate) mod test { use assert_matches::assert_matches; use bitcoin::{absolute, blockdata::script::PushBytes, psbt, Network, Sequence}; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + use crate::blockchain::{GetHeight, SyncControl, WalletSync}; use crate::database::Database; + use crate::database::{BatchDatabase, MemoryDatabase}; use crate::types::KeychainKind; use super::*; @@ -1927,6 +1959,59 @@ pub(crate) mod test { // OP_PUSH. const P2WPKH_FAKE_WITNESS_SIZE: usize = 106; + #[derive(Default)] + struct CountingBlockchain { + setup_with_control_calls: Arc, + sync_calls: Arc, + sync_with_control_calls: Arc, + } + + impl GetHeight for CountingBlockchain { + fn get_height(&self) -> Result { + Ok(100) + } + } + + impl WalletSync for CountingBlockchain { + fn wallet_setup( + &self, + _database: &RefCell, + _progress_update: Box, + ) -> Result<(), Error> { + self.setup_with_control_calls.fetch_add(1, Ordering::SeqCst); + Ok(()) + } + + fn wallet_sync( + &self, + _database: &RefCell, + _progress_update: Box, + ) -> Result<(), Error> { + self.sync_calls.fetch_add(1, Ordering::SeqCst); + Ok(()) + } + + fn wallet_setup_with_control( + &self, + _database: &RefCell, + _progress_update: Box, + control: SyncControl, + ) -> Result<(), Error> { + self.setup_with_control_calls.fetch_add(1, Ordering::SeqCst); + control.check_cancelled() + } + + fn wallet_sync_with_control( + &self, + _database: &RefCell, + _progress_update: Box, + control: SyncControl, + ) -> Result<(), Error> { + self.sync_with_control_calls.fetch_add(1, Ordering::SeqCst); + control.check_cancelled() + } + } + #[test] fn test_descriptor_checksum() { let (wallet, _, _) = get_funded_wallet(get_test_wpkh()); @@ -1938,6 +2023,97 @@ pub(crate) mod test { ); } + #[test] + fn test_sync_with_control_uses_wallet_sync_with_control() { + let wallet = Wallet::new( + get_test_wpkh(), + None, + Network::Regtest, + MemoryDatabase::new(), + ) + .unwrap(); + let blockchain = CountingBlockchain::default(); + + wallet.sync(&blockchain, SyncOptions::default()).unwrap(); + wallet + .sync_with_control(&blockchain, SyncOptions::default(), SyncControl::default()) + .unwrap(); + wallet + .sync( + &blockchain, + SyncOptions { + progress: None, + sync_control: Some(SyncControl::default()), + }, + ) + .unwrap(); + + assert_eq!( + blockchain.setup_with_control_calls.load(Ordering::SeqCst), + 0 + ); + assert_eq!(blockchain.sync_with_control_calls.load(Ordering::SeqCst), 3); + assert_eq!(blockchain.sync_calls.load(Ordering::SeqCst), 0); + } + + #[test] + fn test_sync_with_control_cancellation_propagates() { + let wallet = Wallet::new( + get_test_wpkh(), + None, + Network::Regtest, + MemoryDatabase::new(), + ) + .unwrap(); + let blockchain = CountingBlockchain::default(); + + wallet.sync(&blockchain, SyncOptions::default()).unwrap(); + + let cancel = Arc::new(std::sync::atomic::AtomicBool::new(true)); + let err = wallet + .sync( + &blockchain, + SyncOptions { + progress: None, + sync_control: Some(SyncControl { + cancel: Some(cancel), + deadline: None, + }), + }, + ) + .unwrap_err(); + + assert_matches!(err, Error::SyncCancelled); + } + + #[test] + fn test_sync_with_control_cancellation_propagates_on_setup_path() { + let wallet = Wallet::new( + "wpkh(tpubEBr4i6yk5nf5DAaJpsi9N2pPYBeJ7fZ5Z9rmN4977iYLCGco1VyjB9tvvuvYtfZzjD5A8igzgw3HeWeeKFmanHYqksqZXYXGsw5zjnj7KM9/*)", + None, + Network::Testnet, + MemoryDatabase::new(), + ) + .unwrap(); + let blockchain = CountingBlockchain::default(); + + let cancel = Arc::new(std::sync::atomic::AtomicBool::new(true)); + let err = wallet + .sync( + &blockchain, + SyncOptions { + progress: None, + sync_control: Some(SyncControl { + cancel: Some(cancel), + deadline: None, + }), + }, + ) + .unwrap_err(); + + assert_matches!(err, Error::SyncCancelled); + } + #[test] fn test_db_checksum() { let (wallet, _, _) = get_funded_wallet(get_test_wpkh());