fix: use multiple connections for ws subscriptions #365

Open · wants to merge 7 commits into master

153 changes: 116 additions & 37 deletions magicblock-account-updates/src/remote_account_updates_shard.rs
@@ -7,12 +7,12 @@ use std::{
};

use conjunto_transwise::RpcProviderConfig;
use futures_util::StreamExt;
use futures_util::{stream::FuturesUnordered, Stream, StreamExt};
use log::*;
use magicblock_metrics::metrics;
use solana_account_decoder::{UiAccountEncoding, UiDataSliceConfig};
use solana_account_decoder::{UiAccount, UiAccountEncoding, UiDataSliceConfig};
use solana_pubsub_client::nonblocking::pubsub_client::PubsubClient;
use solana_rpc_client_api::config::RpcAccountInfoConfig;
use solana_rpc_client_api::{config::RpcAccountInfoConfig, response::Response};
use solana_sdk::{
clock::{Clock, Slot},
commitment_config::CommitmentConfig,
@@ -24,6 +24,13 @@ use tokio::sync::mpsc::Receiver;
use tokio_stream::StreamMap;
use tokio_util::sync::CancellationToken;

type BoxFn = Box<
dyn FnOnce() -> Pin<Box<dyn Future<Output = ()> + Send + 'static>> + Send,
>;

type SubscriptionStream =
Pin<Box<dyn Stream<Item = Response<UiAccount>> + Send + 'static>>;

#[derive(Debug, Error)]
pub enum RemoteAccountUpdatesShardError {
#[error(transparent)]
@@ -65,11 +72,8 @@ impl RemoteAccountUpdatesShard {
// Create a pubsub client
info!("Shard {}: Starting", self.shard_id);
let ws_url = self.rpc_provider_config.ws_url();
let pubsub_client = PubsubClient::new(ws_url)
.await
.map_err(RemoteAccountUpdatesShardError::PubsubClientError)?;
// For every account, we only want the updates, not the actual content of the accounts
let rpc_account_info_config = Some(RpcAccountInfoConfig {
let config = RpcAccountInfoConfig {
commitment: self
.rpc_provider_config
.commitment()
@@ -80,21 +84,13 @@
length: 0,
}),
min_context_slot: None,
});
};
let mut pool = PubsubPool::new(ws_url, config).await?;
// Subscribe to the clock from the RPC (to figure out the latest slot)
let (mut clock_stream, clock_unsubscribe) = pubsub_client
.account_subscribe(&clock::ID, rpc_account_info_config.clone())
.await
.map_err(RemoteAccountUpdatesShardError::PubsubClientError)?;
let mut clock_stream = pool.subscribe(clock::ID).await?;
let mut clock_slot = 0;
// We'll store useful maps for each of the account subscriptions
let mut account_streams = StreamMap::new();
// rust compiler is not yet smart enough to figure out the exact type
type BoxFn = Box<
dyn FnOnce() -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
+ Send,
>;
let mut account_unsubscribes: HashMap<Pubkey, BoxFn> = HashMap::new();
const LOG_CLOCK_FREQ: u64 = 100;
let mut log_clock_count = 0;

@@ -118,19 +114,17 @@
} else {
warn!("Shard {}: Received empty clock data", self.shard_id);
}
self.try_to_override_last_known_update_slot(clock::ID, clock_slot);
}
// When we receive a message to start monitoring an account
Some((pubkey, unsub)) = self.monitoring_request_receiver.recv() => {
if unsub {
let Some(request) = account_unsubscribes.remove(&pubkey) else {
continue;
};
account_streams.remove(&pubkey);
metrics::set_subscriptions_count(account_streams.len(), &self.shard_id);
request().await;
pool.unsubscribe(&pubkey).await;
continue;
}
if account_unsubscribes.contains_key(&pubkey) {
if pool.subscribed(&pubkey) {
continue;
}
debug!(
@@ -139,12 +133,10 @@
pubkey,
clock_slot
);
let (stream, unsubscribe) = pubsub_client
.account_subscribe(&pubkey, rpc_account_info_config.clone())
.await
.map_err(RemoteAccountUpdatesShardError::PubsubClientError)?;
let stream = pool
.subscribe(pubkey)
.await?;
account_streams.insert(pubkey, stream);
account_unsubscribes.insert(pubkey, unsubscribe);
metrics::set_subscriptions_count(account_streams.len(), &self.shard_id);
self.try_to_override_first_subscribed_slot(pubkey, clock_slot);
}
@@ -164,17 +156,9 @@
}
}
// Cleanup all subscriptions and wait for proper shutdown
for (pubkey, account_unsubscribes) in account_unsubscribes.into_iter() {
info!(
"Shard {}: Account monitoring killed: {:?}",
self.shard_id, pubkey
);
account_unsubscribes().await;
}
clock_unsubscribe().await;
drop(account_streams);
drop(clock_stream);
pubsub_client.shutdown().await?;
pool.shutdown().await;
info!("Shard {}: Stopped", self.shard_id);
// Done
Ok(())
@@ -236,3 +220,98 @@ impl RemoteAccountUpdatesShard {
}
}
}

struct PubsubPool {
clients: Vec<PubSubConnection>,
unsubscribes: HashMap<Pubkey, (usize, BoxFn)>,
config: RpcAccountInfoConfig,
}

impl PubsubPool {
async fn new(
url: &str,
config: RpcAccountInfoConfig,
) -> Result<Self, RemoteAccountUpdatesShardError> {
// 8 is pretty much arbitrary, but a sane value for the number
// of connections per RPC upstream; we don't overcomplicate things
// here, as the whole cloning pipeline will be rewritten quite soon
const CONNECTIONS_PER_POOL: usize = 8;
let mut clients = Vec::with_capacity(CONNECTIONS_PER_POOL);
let mut connections: FuturesUnordered<_> = (0..CONNECTIONS_PER_POOL)
.map(|_| PubSubConnection::new(url))
.collect();
while let Some(c) = connections.next().await {
clients.push(c?);
}
Ok(Self {
clients,
unsubscribes: HashMap::new(),
config,
})
}

async fn subscribe(
&mut self,
pubkey: Pubkey,
) -> Result<SubscriptionStream, RemoteAccountUpdatesShardError> {
let (index, client) = self
.clients
.iter_mut()
.enumerate()
.min_by(|a, b| a.1.subs.cmp(&b.1.subs))
.expect("clients vec is always greater than 0");
let (stream, unsubscribe) = client
.inner
.account_subscribe(&pubkey, Some(self.config.clone()))
.await
.map_err(RemoteAccountUpdatesShardError::PubsubClientError)?;
client.subs += 1;
// SAFETY:
// we never drop the PubsubPool before the returned subscription stream
// so the lifetime of the stream can be safely extended to 'static
#[allow(clippy::missing_transmute_annotations)]
let stream = unsafe { std::mem::transmute(stream) };
self.unsubscribes.insert(pubkey, (index, unsubscribe));
Ok(stream)
}

async fn unsubscribe(&mut self, pubkey: &Pubkey) {
let Some((index, callback)) = self.unsubscribes.remove(pubkey) else {
return;
};
callback().await;
let Some(client) = self.clients.get_mut(index) else {
return;
};
client.subs = client.subs.saturating_sub(1);
}

fn subscribed(&mut self, pubkey: &Pubkey) -> bool {
self.unsubscribes.contains_key(pubkey)
}

async fn shutdown(&mut self) {
// Cleanup all subscriptions and wait for proper shutdown
for (pubkey, (_, callback)) in self.unsubscribes.drain() {
info!("Account monitoring killed: {:?}", pubkey);
callback().await;
}
for client in self.clients.drain(..) {
let _ = client.inner.shutdown().await;
}
}
}

struct PubSubConnection {
inner: PubsubClient,
subs: usize,
}

impl PubSubConnection {
async fn new(url: &str) -> Result<Self, RemoteAccountUpdatesShardError> {
let inner = PubsubClient::new(url)
.await
.map_err(RemoteAccountUpdatesShardError::PubsubClientError)?;
Ok(Self { inner, subs: 0 })
}
}
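
The heart of this change is the selection policy in `PubsubPool::subscribe`: each new account subscription is placed on whichever pooled connection currently carries the fewest subscriptions. Below is a minimal, self-contained sketch of that policy with toy types (`ToyConn` and `ToyPool` are stand-ins, not the PR's `PubSubConnection`/`PubsubPool`, and no websocket traffic is involved):

```rust
// Toy illustration of least-loaded selection across a fixed pool of
// connections; only the bookkeeping is shown, no actual pubsub traffic.
struct ToyConn {
    id: usize,
    subs: usize,
}

struct ToyPool {
    conns: Vec<ToyConn>,
}

impl ToyPool {
    fn new(size: usize) -> Self {
        Self {
            conns: (0..size).map(|id| ToyConn { id, subs: 0 }).collect(),
        }
    }

    /// Pick the connection with the fewest active subscriptions and
    /// record the new one, mirroring PubsubPool::subscribe.
    fn pick(&mut self) -> usize {
        let conn = self
            .conns
            .iter_mut()
            .min_by_key(|c| c.subs)
            .expect("pool is never empty");
        conn.subs += 1;
        conn.id
    }
}

fn main() {
    let mut pool = ToyPool::new(3);
    // New subscriptions spread evenly across the three connections.
    let picks: Vec<usize> = (0..5).map(|_| pool.pick()).collect();
    assert_eq!(picks, vec![0, 1, 2, 0, 1]);
}
```

Because `unsubscribe` decrements the per-connection counter (the PR uses `saturating_sub`), a connection that loses subscriptions naturally becomes the preferred target for the next request.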
20 changes: 9 additions & 11 deletions magicblock-account-updates/tests/remote_account_updates.rs
@@ -14,15 +14,15 @@ use test_tools::skip_if_devnet_down;
use tokio::time::sleep;
use tokio_util::sync::CancellationToken;

fn setup() -> (
async fn setup() -> (
RemoteAccountUpdatesClient,
CancellationToken,
tokio::task::JoinHandle<()>,
) {
// Create account updates worker and client
let mut worker = RemoteAccountUpdatesWorker::new(
vec![RpcProviderConfig::devnet(), RpcProviderConfig::devnet()],
Duration::from_secs(1), // We constantly refresh stuff to make it struggle
Duration::from_secs(50 * 60), // We use the same config as in production
);
let client = RemoteAccountUpdatesClient::new(&worker);
// Run the worker in a separate task
@@ -35,6 +35,8 @@ fn setup() -> (
.await
})
};
// wait a bit for websocket connections to establish
sleep(Duration::from_millis(5_000)).await;
// Ready to run
(client, cancellation_token, worker_handle)
}
@@ -43,11 +45,9 @@ fn setup() -> (
async fn test_devnet_monitoring_clock_sysvar_changes_over_time() {
skip_if_devnet_down!();
// Create account updates worker and client
let (client, cancellation_token, worker_handle) = setup();
let (client, cancellation_token, worker_handle) = setup().await;
// The clock will change every slot, perfect for testing updates
let sysvar_clock = clock::ID;
// Before starting the monitoring, we should know nothing about the clock
assert!(client.get_last_known_update_slot(&sysvar_clock).is_none());
// Start the monitoring
assert!(client
.ensure_account_monitoring(&sysvar_clock)
@@ -74,15 +74,14 @@ async fn test_devnet_monitoring_clock_sysvar_changes_over_time() {
async fn test_devnet_monitoring_multiple_accounts_at_the_same_time() {
skip_if_devnet_down!();
// Create account updates worker and client
let (client, cancellation_token, worker_handle) = setup();
let (client, cancellation_token, worker_handle) = setup().await;
// Devnet accounts to be monitored for this test
let sysvar_rent = rent::ID;
let sysvar_sh = slot_hashes::ID;
let sysvar_clock = clock::ID;
// We shouldn't know anything about the accounts until we subscribe
assert!(client.get_last_known_update_slot(&sysvar_rent).is_none());
assert!(client.get_last_known_update_slot(&sysvar_sh).is_none());
assert!(client.get_last_known_update_slot(&sysvar_clock).is_none());
// Start monitoring the accounts now
assert!(client.ensure_account_monitoring(&sysvar_rent).await.is_ok());
assert!(client.ensure_account_monitoring(&sysvar_sh).await.is_ok());
@@ -105,15 +104,14 @@ async fn test_devnet_monitoring_multiple_accounts_at_the_same_time() {
async fn test_devnet_monitoring_some_accounts_only() {
skip_if_devnet_down!();
// Create account updates worker and client
let (client, cancellation_token, worker_handle) = setup();
let (client, cancellation_token, worker_handle) = setup().await;
// Devnet accounts for this test
let sysvar_rent = rent::ID;
let sysvar_sh = slot_hashes::ID;
let sysvar_clock = clock::ID;
// We shouldn't know anything about the accounts until we subscribe
assert!(client.get_last_known_update_slot(&sysvar_rent).is_none());
assert!(client.get_last_known_update_slot(&sysvar_sh).is_none());
assert!(client.get_last_known_update_slot(&sysvar_clock).is_none());
// Start monitoring only some of the accounts
assert!(client.ensure_account_monitoring(&sysvar_rent).await.is_ok());
assert!(client.ensure_account_monitoring(&sysvar_sh).await.is_ok());
@@ -122,7 +120,7 @@ async fn test_devnet_monitoring_some_accounts_only() {
// Check that we detected the accounts changes only on the accounts we monitored
assert!(client.get_last_known_update_slot(&sysvar_rent).is_none()); // Rent doesn't change
assert!(client.get_last_known_update_slot(&sysvar_sh).is_some());
assert!(client.get_last_known_update_slot(&sysvar_clock).is_none());
assert!(client.get_last_known_update_slot(&sysvar_clock).is_some());
// Cleanup everything correctly
cancellation_token.cancel();
assert!(worker_handle.await.is_ok());
@@ -132,7 +130,7 @@ async fn test_devnet_monitoring_some_accounts_only() {
async fn test_devnet_monitoring_invalid_and_immutable_and_program_account() {
skip_if_devnet_down!();
// Create account updates worker and client
let (client, cancellation_token, worker_handle) = setup();
let (client, cancellation_token, worker_handle) = setup().await;
// Devnet accounts for this test (none of them should change)
let new_account = Keypair::new().pubkey();
let system_program = system_program::ID;
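
Related to the new five-second wait in the test `setup()` above: `PubsubPool::new` now opens all of its websocket connections up front, concurrently, and only returns once every handshake has finished. Here is a runnable sketch of that concurrent-connect pattern, with the actual network call stubbed out (`connect` is a placeholder, not the PR's `PubSubConnection::new`):

```rust
use futures_util::{stream::FuturesUnordered, StreamExt};

// Placeholder for PubSubConnection::new(url): pretend the handshake succeeds.
async fn connect(i: usize) -> Result<usize, String> {
    Ok(i)
}

#[tokio::main]
async fn main() -> Result<(), String> {
    // Kick off all connection attempts at once...
    let mut pending: FuturesUnordered<_> = (0..8).map(connect).collect();
    let mut ready = Vec::with_capacity(8);
    // ...and collect them in whatever order they complete, bailing out on
    // the first failure, just like PubsubPool::new does with `?`.
    while let Some(conn) = pending.next().await {
        ready.push(conn?);
    }
    println!("established {} connections", ready.len());
    Ok(())
}
```

The failure mode matches the PR: if any single connection cannot be established, pool construction fails rather than continuing with fewer connections.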