Skip to content

Commit 2493fa3

Browse files
committed
feat: batch updates
1 parent 53e50a6 commit 2493fa3

File tree

2 files changed

+131
-58
lines changed

2 files changed

+131
-58
lines changed

src/bin/pyth_reader.rs

Lines changed: 64 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,19 @@ use solana_metrics::datapoint_info;
1313
use solana_sdk::account::Account;
1414
use solana_sdk::commitment_config::CommitmentConfig;
1515
use solana_sdk::pubkey::Pubkey;
16+
use std::collections::HashMap;
1617
use std::collections::HashSet;
1718
use std::path::PathBuf;
1819
use std::str::FromStr;
20+
use std::sync::Arc;
1921
use std::time::Instant;
22+
use tokio::sync::Mutex;
2023
use tokio::task;
2124
use tokio::time::Duration;
2225
use tokio_stream::StreamExt;
2326
use tracing::{debug, error, info, warn};
2427
use tracing_subscriber::{fmt, EnvFilter};
2528
use url::Url;
26-
2729
#[derive(Debug, Serialize, Deserialize)]
2830
struct PriceUpdate {
2931
#[serde(rename = "type")]
@@ -33,13 +35,15 @@ struct PriceUpdate {
3335

3436
#[derive(Debug, Serialize, Deserialize)]
3537
struct PublisherPriceUpdate {
36-
#[serde(rename = "type")]
37-
update_type: String,
3838
publisher: String,
3939
feed_id: String,
40-
price_info: PriceInfo,
40+
slot: u64, // Add this field
41+
price: String,
4142
}
4243

44+
type PublisherKey = (String, String); // (feed_id, publisher)
45+
type PublisherBuffer = Arc<Mutex<HashMap<PublisherKey, PublisherPriceUpdate>>>;
46+
4347
#[derive(Debug, Serialize, Deserialize)]
4448
struct PriceFeed {
4549
id: String,
@@ -99,6 +103,7 @@ struct Args {
99103

100104
async fn fetch_price_updates(jetstream: jetstream::Context, config: &AppConfig) -> Result<()> {
101105
info!("Starting Pyth reader");
106+
let publisher_buffer: PublisherBuffer = Arc::new(Mutex::new(HashMap::new()));
102107
let client = PubsubClient::new(config.pyth.websocket_addr.as_str()).await?;
103108
info!(
104109
"Connected to Pyth WebSocket at {}",
@@ -123,6 +128,47 @@ async fn fetch_price_updates(jetstream: jetstream::Context, config: &AppConfig)
123128
let mut update_count = 0;
124129
let mut unique_price_feeds = HashSet::new();
125130
let mut last_report_time = Instant::now();
131+
let jetstream_clone = jetstream.clone();
132+
let buffer_clone = publisher_buffer.clone();
133+
let mut msg_id_counter = 0;
134+
tokio::spawn(async move {
135+
let mut interval = tokio::time::interval(Duration::from_millis(100));
136+
loop {
137+
interval.tick().await;
138+
139+
let updates: Vec<PublisherPriceUpdate> = {
140+
let mut buf = buffer_clone.lock().await;
141+
if buf.is_empty() {
142+
continue;
143+
}
144+
buf.drain().map(|(_, v)| v).collect()
145+
}; // <- lock released here
146+
147+
// Serialize as JSON array
148+
let body = match serde_json::to_string(&updates) {
149+
Ok(b) => b,
150+
Err(e) => {
151+
warn!("Failed to serialize batch of publisher updates: {}", e);
152+
continue;
153+
}
154+
};
155+
156+
// Use a random ID as Nats-Msg-Id for the batch
157+
let msg_id = format!("publisher_batch:{}", msg_id_counter);
158+
msg_id_counter += 1;
159+
let mut headers = HeaderMap::new();
160+
headers.insert("Nats-Msg-Id", msg_id.as_str());
161+
info!("body: {:#?},msg_id: {}", body.len(), msg_id);
162+
if let Err(e) = jetstream_clone
163+
.publish_with_headers("pyth.publisher.updates", headers, body.into())
164+
.await
165+
{
166+
warn!("Failed to publish batch publisher updates: {}", e);
167+
} else {
168+
debug!("Published {} publisher updates in a batch", updates.len());
169+
}
170+
}
171+
});
126172

127173
while let Some(update) = notif.next().await {
128174
debug!("Received price update");
@@ -133,13 +179,15 @@ async fn fetch_price_updates(jetstream: jetstream::Context, config: &AppConfig)
133179
continue;
134180
}
135181
};
182+
136183
let price_account: PythnetPriceAccount = match load_price_account(&account.data) {
137184
Ok(pyth_account) => *pyth_account,
138185
Err(_) => {
139186
debug!("Not a price account, skipping");
140187
continue;
141188
}
142189
};
190+
143191
// info!(
144192
// "Price Account: {:#?}, account: {:#?} \n\n",
145193
// price_account, account
@@ -207,7 +255,6 @@ async fn fetch_price_updates(jetstream: jetstream::Context, config: &AppConfig)
207255
});
208256

209257
for component in price_account.comp {
210-
let jetstream_clone = jetstream.clone();
211258
let publisher = component.publisher.to_string();
212259
let publisher_price_update_message_id = format!(
213260
"{}:{}:{}",
@@ -219,37 +266,19 @@ async fn fetch_price_updates(jetstream: jetstream::Context, config: &AppConfig)
219266
.insert("Nats-Msg-Id", publisher_price_update_message_id.as_str());
220267

221268
let publisher_price_update = PublisherPriceUpdate {
222-
update_type: "publisher_price_update".to_string(),
223-
feed_id: price_account.prod.to_string(),
224-
publisher: publisher,
225-
price_info: PriceInfo {
226-
price: price_account.agg.price.to_string(),
227-
conf: price_account.agg.conf.to_string(),
228-
expo: price_account.expo,
229-
publish_time: price_account.timestamp,
230-
slot: update.context.slot, // Add this field
231-
},
269+
feed_id: update.value.pubkey.to_string(),
270+
publisher: publisher.clone(),
271+
price: price_account.agg.price.to_string(),
272+
slot: update.context.slot, // Add this field
232273
};
233-
info!("Publisher price update: {:?}", publisher_price_update);
234-
let publisher_price_update_message =
235-
serde_json::to_string(&publisher_price_update)?;
236-
237-
task::spawn(async move {
238-
match jetstream_clone
239-
.publish_with_headers(
240-
"pyth.publisher.updates",
241-
publisher_price_updates,
242-
publisher_price_update_message.into(),
243-
)
244-
.await
245-
{
246-
Ok(_) => debug!(
247-
"Published publisher price update to JetStream with ID: {}",
248-
publisher_price_update_message_id
249-
),
250-
Err(e) => warn!("Failed to publish price update to JetStream: {}", e),
251-
}
252-
});
274+
if publisher_price_update.feed_id == "7jAVut34sgRj6erznsYvLYvjc9GJwXTpN88ThZSDJ65G"
275+
&& publisher == "6DNocjFJjocPLZnKBZyEJAC5o2QaiT5Mx8AkphfxDm5i"
276+
{
277+
info!("publisher_price_update: {:#?}", publisher_price_update);
278+
}
279+
let key = (publisher_price_update.feed_id.clone(), publisher);
280+
let mut buf = publisher_buffer.lock().await;
281+
buf.insert(key, publisher_price_update);
253282
}
254283
update_count += 1;
255284
unique_price_feeds.insert(price_account.prod);

src/bin/websocket_server.rs

Lines changed: 67 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use hyper::{Request, Response};
1111
use hyper_util::rt::TokioIo;
1212
use pyth_stream::utils::setup_jetstream;
1313
use serde::{Deserialize, Serialize};
14+
use serde_json::json;
1415
use std::clone::Clone;
1516
use std::collections::{HashMap, HashSet};
1617
use std::net::SocketAddr;
@@ -328,11 +329,10 @@ struct PriceUpdate {
328329

329330
#[derive(Debug, Serialize, Deserialize)]
330331
struct PublisherPriceUpdate {
331-
#[serde(rename = "type")]
332-
update_type: String,
333332
publisher: String,
334333
feed_id: String,
335-
price_info: PriceInfo,
334+
slot: u64,
335+
price: String,
336336
}
337337

338338
#[derive(Debug, Serialize, Deserialize)]
@@ -379,34 +379,78 @@ async fn handle_nats_publisher_updates_messages(
379379
while let Some(msg) = messages.next().await {
380380
match msg {
381381
Ok(msg) => {
382-
debug!("Received NATS message {:?}", msg.payload);
383-
let mut publisher_price_update: PublisherPriceUpdate =
382+
let updates: Vec<PublisherPriceUpdate> =
384383
match serde_json::from_slice(&msg.payload) {
385-
Ok(update) => update,
384+
Ok(u) => u,
386385
Err(e) => {
387-
warn!(error = %e, "Failed to parse publisher price update");
386+
warn!(error = %e, "Failed to parse publisher price update batch");
388387
continue;
389388
}
390389
};
391-
let clients = clients.lock().await;
390+
debug!("Parsed {} publisher updates in batch", updates.len());
391+
// Build per-client payloads while holding the lock,
392+
// but DO NOT send while holding it.
393+
let mut to_send: Vec<(String, mpsc::UnboundedSender<Message>, String)> =
394+
Vec::new();
395+
{
396+
let clients = clients.lock().await;
397+
398+
for (client_addr, client_data) in clients.iter() {
399+
// Filter only updates the client cares about
400+
let filtered: Vec<_> = updates
401+
.iter()
402+
.filter(|u| {
403+
client_data.publisher_subscriptions.contains(&u.publisher)
404+
})
405+
.map(|u| {
406+
if u.feed_id == "7jAVut34sgRj6erznsYvLYvjc9GJwXTpN88ThZSDJ65G" {
407+
info!("publisher_price_update: {:#?}", u);
408+
}
409+
json!({
410+
"publisher": u.publisher,
411+
"feed_id": u.feed_id,
412+
"slot": u.slot,
413+
"price": u.price,
414+
})
415+
})
416+
.collect();
417+
if filtered.is_empty() {
418+
continue;
419+
}
392420

393-
for (client_addr, client_data) in clients.iter() {
421+
info!(
422+
"Preparing batch for client {} ({} updates, subs={:?})",
423+
client_addr,
424+
filtered.len(),
425+
client_data.publisher_subscriptions
426+
);
427+
428+
let batch_json = serde_json::to_string(&json!({
429+
"type": "publisher_price_update",
430+
"updates": filtered
431+
}))
432+
.unwrap();
433+
434+
// Clone the sender so we can drop the lock before sending
435+
to_send.push((
436+
client_addr.clone(),
437+
client_data.sender.clone(),
438+
batch_json,
439+
));
440+
}
441+
} // lock dropped here
442+
443+
// Now send the prepared batches
444+
for (client_addr, sender, batch_json) in to_send {
394445
info!(
395-
"Number of connected clients: {:#?}, {:#?}",
396-
client_data.publisher_subscriptions, publisher_price_update,
446+
"Sending {}-byte batch to client {}",
447+
batch_json.len(),
448+
client_addr
397449
);
398-
if client_data
399-
.publisher_subscriptions
400-
.contains(&publisher_price_update.publisher)
401-
{
402-
info!("Sending update to client: {}", client_addr);
403-
let update_json =
404-
serde_json::to_string(&publisher_price_update).unwrap();
405-
if let Err(e) = client_data.sender.send(Message::Text(update_json)) {
406-
warn!(client_addr = %client_addr, error = %e, "Failed to send publisher price update to client");
407-
} else {
408-
info!(client_addr = %client_addr, "Successfully sent publisher price update to client");
409-
}
450+
if let Err(e) = sender.send(Message::Text(batch_json)) {
451+
warn!(client_addr = %client_addr, error = %e, "Failed to send publisher batch");
452+
} else {
453+
info!(client_addr = %client_addr, "Successfully sent publisher batch");
410454
}
411455
}
412456
}

0 commit comments

Comments
 (0)