forked from xai-org/x-algorithm
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy paththunder_service.rs
More file actions
339 lines (296 loc) · 12.2 KB
/
Copy paththunder_service.rs
File metadata and controls
339 lines (296 loc) · 12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
use lazy_static::lazy_static;
use log::{debug, info, warn};
use std::cmp::Reverse;
use std::collections::HashSet;
use std::sync::Arc;
use std::time::{Instant, SystemTime, UNIX_EPOCH};
use tokio::sync::Semaphore;
use tonic::{Request, Response, Status};
use xai_thunder_proto::{
GetInNetworkPostsRequest, GetInNetworkPostsResponse, LightPost,
in_network_posts_service_server::{InNetworkPostsService, InNetworkPostsServiceServer},
};
use crate::config::{
MAX_INPUT_LIST_SIZE, MAX_POSTS_TO_RETURN, MAX_VIDEOS_TO_RETURN,
};
use crate::metrics::{
GET_IN_NETWORK_POSTS_COUNT, GET_IN_NETWORK_POSTS_DURATION,
GET_IN_NETWORK_POSTS_DURATION_WITHOUT_STRATO, GET_IN_NETWORK_POSTS_EXCLUDED_SIZE,
GET_IN_NETWORK_POSTS_FOLLOWING_SIZE, GET_IN_NETWORK_POSTS_FOUND_FRESHNESS_SECONDS,
GET_IN_NETWORK_POSTS_FOUND_POSTS_PER_AUTHOR, GET_IN_NETWORK_POSTS_FOUND_REPLY_RATIO,
GET_IN_NETWORK_POSTS_FOUND_TIME_RANGE_SECONDS, GET_IN_NETWORK_POSTS_FOUND_UNIQUE_AUTHORS,
GET_IN_NETWORK_POSTS_MAX_RESULTS, IN_FLIGHT_REQUESTS, REJECTED_REQUESTS, Timer,
};
use crate::posts::post_store::PostStore;
use crate::strato_client::StratoClient;
/// Service state behind the `InNetworkPostsService` gRPC implementation.
///
/// Every field is an `Arc`, so the heavy dependencies are shared handles
/// rather than owned copies.
pub struct ThunderServiceImpl {
/// Post store queried for the posts (or videos) of the followed users.
post_store: Arc<PostStore>,
/// Strato client used to fetch a user's following list when the request
/// does not carry one.
strato_client: Arc<StratoClient>,
/// Admission-control semaphore: a request that cannot get a permit
/// immediately is rejected instead of being queued.
request_semaphore: Arc<Semaphore>,
}
impl ThunderServiceImpl {
    /// Builds the service from its shared dependencies.
    ///
    /// `max_concurrent_requests` is the number of permits in the admission
    /// semaphore, i.e. how many requests may be processed at the same time
    /// before new ones are rejected.
    pub fn new(
        post_store: Arc<PostStore>,
        strato_client: Arc<StratoClient>,
        max_concurrent_requests: usize,
    ) -> Self {
        info!(
            "Initializing ThunderService with max_concurrent_requests={}",
            max_concurrent_requests
        );
        Self {
            post_store,
            strato_client,
            request_semaphore: Arc::new(Semaphore::new(max_concurrent_requests)),
        }
    }

    /// Wraps this service in its generated gRPC server type, configured to
    /// both accept and send Zstd-compressed messages.
    pub fn server(self) -> InNetworkPostsServiceServer<Self> {
        InNetworkPostsServiceServer::new(self)
            .accept_compressed(tonic::codec::CompressionEncoding::Zstd)
            .send_compressed(tonic::codec::CompressionEncoding::Zstd)
    }

    /// Computes summary statistics over `posts`, reports them as metrics and
    /// logs them at debug level.
    ///
    /// `stage` is used as the metric label value and as the log prefix so the
    /// same statistics can be told apart per pipeline stage (the callers in
    /// this file pass `"retrieved"` and `"scored"`).
    ///
    /// Reported values: freshness (seconds since the newest post), time range
    /// (newest minus oldest `created_at`), reply ratio, number of unique
    /// authors and posts per author. An empty slice only produces a debug log.
    ///
    /// This helper never panics on a bad system clock: if the current time
    /// cannot be expressed as seconds since the Unix epoch, only the
    /// freshness metric is skipped.
    fn analyze_and_report_post_statistics(posts: &[LightPost], stage: &str) {
        if posts.is_empty() {
            debug!("[{}] No posts found for analysis", stage);
            return;
        }

        // `None` when the system clock is set before the Unix epoch. This is a
        // metrics-only helper, so degrade instead of failing the request.
        let now: Option<i64> = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .ok()
            .map(|elapsed| elapsed.as_secs() as i64);
        if now.is_none() {
            debug!(
                "[{}] System clock is before the Unix epoch; skipping freshness metric",
                stage
            );
        }

        let newest_created_at = posts.iter().map(|post| post.created_at).max();
        let oldest_created_at = posts.iter().map(|post| post.created_at).min();

        // Seconds elapsed since the most recent post.
        let time_since_most_recent = now
            .zip(newest_created_at)
            .map(|(now, newest)| now - newest);
        // Span covered by the posts. Equal to
        // (now - oldest) - (now - newest), but independent of the clock.
        let time_range = newest_created_at
            .zip(oldest_created_at)
            .map(|(newest, oldest)| newest - oldest);

        // Replies vs. original posts.
        let total_count = posts.len();
        let reply_count = posts.iter().filter(|post| post.is_reply).count();
        let original_count = total_count - reply_count;
        let reply_ratio = reply_count as f64 / total_count as f64;

        // `posts` is non-empty here, so there is always at least one author
        // and the division below cannot be by zero.
        let unique_author_count = posts
            .iter()
            .map(|post| post.author_id)
            .collect::<HashSet<_>>()
            .len();
        let posts_per_author = total_count as f64 / unique_author_count as f64;

        // Report metrics with stage label
        if let Some(freshness) = time_since_most_recent {
            GET_IN_NETWORK_POSTS_FOUND_FRESHNESS_SECONDS
                .with_label_values(&[stage])
                .observe(freshness as f64);
        }
        if let Some(range) = time_range {
            GET_IN_NETWORK_POSTS_FOUND_TIME_RANGE_SECONDS
                .with_label_values(&[stage])
                .observe(range as f64);
        }
        GET_IN_NETWORK_POSTS_FOUND_REPLY_RATIO
            .with_label_values(&[stage])
            .observe(reply_ratio);
        GET_IN_NETWORK_POSTS_FOUND_UNIQUE_AUTHORS
            .with_label_values(&[stage])
            .observe(unique_author_count as f64);
        GET_IN_NETWORK_POSTS_FOUND_POSTS_PER_AUTHOR
            .with_label_values(&[stage])
            .observe(posts_per_author);

        // Log statistics with stage label
        debug!(
            "[{}] Post statistics: total={}, original={}, replies={}, unique_authors={}, posts_per_author={:.2}, reply_ratio={:.2}, time_since_most_recent={:?}s, time_range={:?}s",
            stage,
            total_count,
            original_count,
            reply_count,
            unique_author_count,
            posts_per_author,
            reply_ratio,
            time_since_most_recent,
            time_range
        );
    }
}
#[tonic::async_trait]
impl InNetworkPostsService for ThunderServiceImpl {
    /// Returns recent posts (or videos) authored by the users the requester
    /// follows, newest first.
    ///
    /// Processing steps:
    /// 1. Admission control: a request that cannot immediately obtain a
    ///    semaphore permit is rejected.
    /// 2. The following list is taken from the request, or fetched from
    ///    Strato when the request's list is empty and `debug` is set.
    /// 3. The following and exclusion lists are capped at
    ///    `MAX_INPUT_LIST_SIZE` entries each.
    /// 4. Retrieval and ranking run on tokio's blocking thread pool.
    ///
    /// `max_results` defaults to `MAX_VIDEOS_TO_RETURN` for video requests
    /// and `MAX_POSTS_TO_RETURN` otherwise when the request does not carry a
    /// positive value.
    ///
    /// # Errors
    ///
    /// * `RESOURCE_EXHAUSTED` when the server is at its concurrency limit.
    /// * `INTERNAL` when the Strato lookup fails or the blocking task cannot
    ///   be joined.
    async fn get_in_network_posts(
        &self,
        request: Request<GetInNetworkPostsRequest>,
    ) -> Result<Response<GetInNetworkPostsResponse>, Status> {
        // Try to acquire a permit without waiting; at capacity we shed load
        // immediately. The permit is *owned* so that it can travel into the
        // blocking task below: if the client cancels and this future is
        // dropped, the capacity slot must stay taken until the blocking work
        // has really finished.
        let permit = match Arc::clone(&self.request_semaphore).try_acquire_owned() {
            Ok(permit) => {
                IN_FLIGHT_REQUESTS.inc();
                permit
            }
            Err(_) => {
                REJECTED_REQUESTS.inc();
                return Err(Status::resource_exhausted(
                    "Server at capacity, please retry",
                ));
            }
        };

        // Decrements the in-flight gauge on drop, whichever way the request
        // ends (success, early error return, cancellation or panic).
        struct InFlightGuard;
        impl Drop for InFlightGuard {
            fn drop(&mut self) {
                IN_FLIGHT_REQUESTS.dec();
            }
        }
        let in_flight_guard = InFlightGuard;

        // Start timer for total latency
        let _total_timer = Timer::new(GET_IN_NETWORK_POSTS_DURATION.clone());

        let req = request.into_inner();
        // Copy the scalar fields up front; the list fields are moved out of
        // `req` further down.
        let user_id = req.user_id;
        let debug_enabled = req.debug;
        let is_video_request = req.is_video_request;

        if debug_enabled {
            info!(
                "Received GetInNetworkPosts request: user_id={}, following_count={}, exclude_tweet_ids={}",
                user_id,
                req.following_user_ids.len(),
                req.exclude_tweet_ids.len(),
            );
        }

        // Fall back to Strato only when the request carries no following list
        // AND the debug flag is set; otherwise the (possibly empty) list from
        // the request is used as-is.
        // NOTE(review): the original comment described the fallback as
        // unconditional. Confirm that gating it on `debug` is intentional.
        let mut following_user_ids = if req.following_user_ids.is_empty() && debug_enabled {
            info!(
                "Following list is empty, fetching from Strato for user {}",
                user_id
            );
            match self
                .strato_client
                .fetch_following_list(user_id as i64, MAX_INPUT_LIST_SIZE as i32)
                .await
            {
                Ok(following_list) => {
                    info!(
                        "Fetched {} following users from Strato for user {}",
                        following_list.len(),
                        user_id
                    );
                    following_list.into_iter().map(|id| id as u64).collect()
                }
                Err(e) => {
                    warn!(
                        "Failed to fetch following list from Strato for user {}: {}",
                        user_id, e
                    );
                    return Err(Status::internal(format!(
                        "Failed to fetch following list: {}",
                        e
                    )));
                }
            }
        } else {
            req.following_user_ids
        };

        // Record metrics for request parameters (sizes before capping)
        GET_IN_NETWORK_POSTS_FOLLOWING_SIZE.observe(following_user_ids.len() as f64);
        GET_IN_NETWORK_POSTS_EXCLUDED_SIZE.observe(req.exclude_tweet_ids.len() as f64);

        // Start timer for latency without strato call
        let _processing_timer = Timer::new(GET_IN_NETWORK_POSTS_DURATION_WITHOUT_STRATO.clone());

        // Default max_results if not specified
        let max_results = if req.max_results > 0 {
            req.max_results as usize
        } else if is_video_request {
            MAX_VIDEOS_TO_RETURN
        } else {
            MAX_POSTS_TO_RETURN
        };
        GET_IN_NETWORK_POSTS_MAX_RESULTS.observe(max_results as f64);

        // Limit following_user_ids and exclude_tweet_ids to first K entries
        let following_count = following_user_ids.len();
        if following_count > MAX_INPUT_LIST_SIZE {
            warn!(
                "Limiting following_user_ids from {} to {} entries for user {}",
                following_count, MAX_INPUT_LIST_SIZE, user_id
            );
            following_user_ids.truncate(MAX_INPUT_LIST_SIZE);
        }

        let exclude_tweet_ids = req.exclude_tweet_ids;
        let exclude_count = exclude_tweet_ids.len();
        if exclude_count > MAX_INPUT_LIST_SIZE {
            warn!(
                "Limiting exclude_tweet_ids from {} to {} entries for user {}",
                exclude_count, MAX_INPUT_LIST_SIZE, user_id
            );
        }

        // Clone Arc references needed inside spawn_blocking
        let post_store = Arc::clone(&self.post_store);
        let request_user_id = user_id as i64;

        // Use spawn_blocking to avoid blocking tokio's async runtime
        let proto_posts = tokio::task::spawn_blocking(move || {
            // Keep the capacity slot and the in-flight gauge tied to the
            // lifetime of the actual work. Declared in this order so the
            // gauge is decremented before the permit is released.
            let _permit = permit;
            let _in_flight_guard = in_flight_guard;

            // Build the exclusion set (capped to the first K entries) for
            // efficient filtering of previously seen posts.
            let exclude_tweet_ids: HashSet<i64> = exclude_tweet_ids
                .into_iter()
                .take(MAX_INPUT_LIST_SIZE)
                .map(|id| id as i64)
                .collect();

            let start_time = Instant::now();

            // Fetch all posts (original + secondary) for the followed users
            let all_posts: Vec<LightPost> = if is_video_request {
                post_store.get_videos_by_users(
                    &following_user_ids,
                    &exclude_tweet_ids,
                    start_time,
                    request_user_id,
                )
            } else {
                post_store.get_all_posts_by_users(
                    &following_user_ids,
                    &exclude_tweet_ids,
                    start_time,
                    request_user_id,
                )
            };

            // Analyze posts and report statistics after querying post_store
            ThunderServiceImpl::analyze_and_report_post_statistics(&all_posts, "retrieved");

            let scored_posts = score_recent(all_posts, max_results);

            // Analyze posts and report statistics after scoring
            ThunderServiceImpl::analyze_and_report_post_statistics(&scored_posts, "scored");

            scored_posts
        })
        .await
        .map_err(|e| Status::internal(format!("Failed to process posts: {}", e)))?;

        if debug_enabled {
            info!(
                "Returning {} posts for user {}",
                proto_posts.len(),
                user_id
            );
        }

        // Record the number of posts returned
        GET_IN_NETWORK_POSTS_COUNT.observe(proto_posts.len() as f64);

        let response = GetInNetworkPostsResponse { posts: proto_posts };
        Ok(Response::new(response))
    }
}
/// Ranks posts purely by recency and keeps the top `max_results`.
///
/// Posts are ordered by `created_at` descending (newest first); anything
/// beyond the first `max_results` entries is discarded. The sort is unstable,
/// so posts sharing a timestamp have no guaranteed relative order.
fn score_recent(mut light_posts: Vec<LightPost>, max_results: usize) -> Vec<LightPost> {
    // `Reverse` flips the key ordering so the largest timestamp sorts first.
    light_posts.sort_unstable_by_key(|post| Reverse(post.created_at));
    // Drop the tail in place; a no-op when there are fewer posts than the cap.
    light_posts.truncate(max_results);
    light_posts
}