Skip to content

Commit 273f741

Browse files
committed
add boundary collector
1 parent d93b6fc commit 273f741

File tree

23 files changed

+1640
-1337
lines changed

23 files changed

+1640
-1337
lines changed

engine/Cargo.lock

-8
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

engine/Cargo.toml

-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ members = [
1010
"language_client_python",
1111
"language_client_ruby/ext/ruby_ffi",
1212
"language_client_typescript",
13-
"boundary-collector",
1413
"sandbox",
1514
]
1615
default-members = [
@@ -27,7 +26,6 @@ default-members = [
2726
"language_client_python",
2827
"language_client_ruby/ext/ruby_ffi",
2928
"language_client_typescript",
30-
"boundary-collector",
3129
]
3230

3331
[workspace.dependencies]

engine/baml-lib/baml-core/src/ir/repr.rs

+4
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ pub struct Walker<'ir, I> {
6060
}
6161

6262
impl IntermediateRepr {
63+
pub fn create_hash(&self) -> String {
64+
todo!("create hash");
65+
}
66+
6367
pub fn create_empty() -> IntermediateRepr {
6468
IntermediateRepr {
6569
enums: vec![],
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
1+
use anyhow::Result;
2+
use internal_baml_diagnostics::SourceFile;
3+
use std::path::PathBuf;
4+
use std::sync::Arc;
5+
use crate::{validate, validate_single_file};
6+
use crate::ir::IntermediateRepr;
17
use super::*;
28

39
#[test]
4-
fn test_signature() {
5-
const BAML_SRC: &str = include_str!("../../test_data/test.baml");
6-
let db = ParserDatabase::new(BamlSource::new(BamlSourceType::String(BAML_SRC.to_string())));
7-
let ir = BamlIr::new(db);
8-
let signature = ir.signature();
9-
println!("{}", signature);
10-
}
10+
fn test_signature() -> Result<()> {
11+
const BAML_SRC: &str = include_str!("test_data/test.baml");
12+
let contents = vec![SourceFile::new_allocated(PathBuf::from("test.baml"), Arc::from(BAML_SRC))];
13+
let mut schema = validate(&PathBuf::from("."), contents);
14+
schema.diagnostics.to_result()?;
15+
16+
let ir = IntermediateRepr::from_parser_database(&schema.db, schema.configuration)?;
17+
let signature = ir.create_baml_hash();
18+
Ok(())
19+
}

engine/baml-runtime/Cargo.toml

-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ anyhow.workspace = true
1515
[lints.rust]
1616
dead_code = "allow"
1717
elided_named_lifetimes = "deny"
18-
unused_imports = "allow"
1918
unused_variables = "allow"
2019

2120
[dependencies]

engine/baml-runtime/src/lib.rs

+5-7
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,8 @@ use internal_baml_core::configuration::Generator;
3838
use internal_baml_core::configuration::GeneratorOutputType;
3939
use on_log_event::LogEventCallbackSync;
4040
use runtime::InternalBamlRuntime;
41-
use serde_json::json;
41+
use tracingv2::storage::storage::FunctionTrackerTrait;
4242
use std::sync::OnceLock;
43-
use tracingv2::storage::storage::Collector;
44-
use tracingv2::storage::storage::BAML_TRACER;
4543

4644
#[cfg(not(target_arch = "wasm32"))]
4745
pub use cli::RuntimeCliDefaults;
@@ -226,7 +224,7 @@ impl BamlRuntime {
226224
test_name: &str,
227225
ctx: &RuntimeContextManager,
228226
on_event: Option<F>,
229-
collector: Option<Arc<Collector>>,
227+
collector: Option<Arc<Box<dyn FunctionTrackerTrait>>>,
230228
) -> (Result<TestResponse>, Option<uuid::Uuid>)
231229
where
232230
F: Fn(FunctionResult),
@@ -333,7 +331,7 @@ impl BamlRuntime {
333331
ctx: &RuntimeContextManager,
334332
tb: Option<&TypeBuilder>,
335333
cb: Option<&ClientRegistry>,
336-
collectors: Option<Vec<Arc<Collector>>>,
334+
collectors: Option<Vec<Arc<Box<dyn FunctionTrackerTrait>>>>,
337335
) -> (Result<FunctionResult>, Option<uuid::Uuid>) {
338336
let fut = self.call_function(function_name, params, ctx, tb, cb, collectors);
339337
self.async_runtime.block_on(fut)
@@ -346,7 +344,7 @@ impl BamlRuntime {
346344
ctx: &RuntimeContextManager,
347345
tb: Option<&TypeBuilder>,
348346
cb: Option<&ClientRegistry>,
349-
collectors: Option<Vec<Arc<Collector>>>,
347+
collectors: Option<Vec<Arc<Box<dyn FunctionTrackerTrait>>>>,
350348
) -> (Result<FunctionResult>, Option<uuid::Uuid>) {
351349
log::trace!("Calling function: {}", function_name);
352350
let span = self.tracer.start_span(&function_name, ctx, params);
@@ -391,7 +389,7 @@ impl BamlRuntime {
391389
ctx: &RuntimeContextManager,
392390
tb: Option<&TypeBuilder>,
393391
cb: Option<&ClientRegistry>,
394-
collectors: Option<Vec<Arc<Collector>>>,
392+
collectors: Option<Vec<Arc<Box<dyn FunctionTrackerTrait>>>>,
395393
) -> Result<FunctionResultStream> {
396394
self.inner.stream_function_impl(
397395
function_name,

engine/baml-runtime/src/runtime/runtime_interface.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::{collections::HashMap, path::PathBuf, sync::Arc};
33
use super::InternalBamlRuntime;
44
use crate::internal::llm_client::traits::WithClientProperties;
55
use crate::internal::llm_client::LLMResponse;
6-
use crate::tracingv2::storage::storage::{Collector, BAML_TRACER};
6+
use crate::tracingv2::storage::storage::{FunctionTrackerTrait, BAML_TRACER};
77
use crate::type_builder::TypeBuilder;
88
use crate::RuntimeContextManager;
99
use crate::{
@@ -488,7 +488,7 @@ impl RuntimeInterface for InternalBamlRuntime {
488488
tracer: Arc<BamlTracer>,
489489
ctx: RuntimeContext,
490490
#[cfg(not(target_arch = "wasm32"))] tokio_runtime: Arc<tokio::runtime::Runtime>,
491-
collectors: Vec<Arc<Collector>>,
491+
collectors: Vec<Arc<Box<dyn FunctionTrackerTrait>>>,
492492
) -> Result<FunctionResultStream> {
493493
let func = self.get_function(&function_name, &ctx)?;
494494
let renderer = PromptRenderer::from_function(&func, self.ir(), &ctx)?;

engine/baml-runtime/src/runtime_interface.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use std::{collections::HashMap, sync::Arc};
99
use crate::internal::llm_client::llm_provider::LLMProvider;
1010
use crate::internal::llm_client::orchestrator::{OrchestrationScope, OrchestratorNode};
1111
use crate::tracing::{BamlTracer, TracingSpan};
12-
use crate::tracingv2::storage::storage::Collector;
12+
use crate::tracingv2::storage::storage::FunctionTrackerTrait;
1313
use crate::type_builder::TypeBuilder;
1414
use crate::types::on_log_event::LogEventCallbackSync;
1515
use crate::{
@@ -47,7 +47,7 @@ pub trait RuntimeInterface {
4747
tracer: Arc<BamlTracer>,
4848
ctx: RuntimeContext,
4949
#[cfg(not(target_arch = "wasm32"))] tokio_runtime: Arc<tokio::runtime::Runtime>,
50-
collectors: Vec<Arc<Collector>>,
50+
collectors: Vec<Arc<Box<dyn FunctionTrackerTrait>>>,
5151
) -> Result<FunctionResultStream>;
5252
}
5353

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
use std::sync::Arc;
2+
use std::time::Duration;
3+
use dashmap::DashMap;
4+
use tokio::sync::{mpsc, Mutex};
5+
use tokio::time::{sleep, timeout};
6+
7+
use baml_types::tracing::events::{FunctionId, TraceData, TraceEvent}; // Assuming TraceEvent is defined
8+
use crate::{
9+
tracingv2::storage::storage::{FunctionTrackerTrait, BAML_TRACER},
10+
BamlRuntime,
11+
};
12+
13+
/// Messages sent to the collector task.
14+
pub enum CollectorMsg {
15+
/// Instruct the collector to gracefully shutdown.
16+
Shutdown,
17+
}
18+
19+
enum NetworkMsg {
20+
/// Send a batch of events to the S3 pusher task.
21+
SendEvents(Vec<Arc<TraceEvent>>),
22+
/// Send a shutdown signal to the S3 pusher task.
23+
Shutdown,
24+
}
25+
26+
#[derive(Debug)]
27+
pub struct Collector {
28+
tracked_ids: Arc<DashMap<FunctionId, usize>>,
29+
shutdown_tx: mpsc::Sender<CollectorMsg>,
30+
// Handle for the main collector task.
31+
join_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
32+
// Channel to send events to the S3 pusher task.
33+
s3_tx: mpsc::Sender<NetworkMsg>,
34+
// Handle for the S3 pusher task.
35+
s3_join_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
36+
}
37+
38+
impl FunctionTrackerTrait for Collector {
39+
fn track_function(&self, fid: FunctionId) {
40+
log::trace!("Tracking function: {:?}", fid);
41+
// Increment the global ref count.
42+
BAML_TRACER.lock().unwrap().inc_ref(&fid);
43+
// Add to our set.
44+
self.tracked_ids.insert(fid, 0);
45+
}
46+
47+
fn untrack_function(&self, fid: &FunctionId) {
48+
self.tracked_ids.remove(fid);
49+
}
50+
}
51+
52+
impl Collector {
53+
/// Creates a new collector and spawns its background tasks.
54+
/// `tps` sets the number of update ticks per second.
55+
pub async fn new(tps: u32, runtime: &BamlRuntime) -> Arc<Self> {
56+
let hash = runtime.create_hash();
57+
58+
// Channel for shutdown signaling to the collector task.
59+
let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
60+
// Channel for sending event batches to the S3 pusher.
61+
let (s3_tx, s3_rx) = mpsc::channel(100);
62+
63+
let collector = Arc::new(Self {
64+
tracked_ids: Arc::new(DashMap::new()),
65+
shutdown_tx,
66+
join_handle: Mutex::new(None),
67+
s3_tx,
68+
s3_join_handle: Mutex::new(None),
69+
});
70+
71+
// Spawn the main collector task.
72+
let main_handle = Self::start(Arc::clone(&collector), tps, shutdown_rx);
73+
{
74+
let mut join_lock = futures::executor::block_on(collector.join_handle.lock());
75+
*join_lock = Some(main_handle);
76+
}
77+
78+
// Spawn the S3 pusher task.
79+
let s3_handle = Self::start_s3_pusher(s3_rx);
80+
{
81+
let mut s3_join_lock = futures::executor::block_on(collector.s3_join_handle.lock());
82+
*s3_join_lock = Some(s3_handle);
83+
}
84+
85+
collector
86+
}
87+
88+
/// Spawns the main collector async task which ticks at the given TPS.
89+
/// It checks for a shutdown signal on every tick.
90+
fn start(
91+
collector: Arc<Self>,
92+
tps: u32,
93+
mut shutdown_rx: mpsc::Receiver<CollectorMsg>,
94+
) -> tokio::task::JoinHandle<()> {
95+
let interval = Duration::from_secs(1) / tps;
96+
tokio::spawn(async move {
97+
loop {
98+
tokio::select! {
99+
// Listen for a shutdown signal.
100+
_ = shutdown_rx.recv() => {
101+
collector.update_events().await;
102+
break;
103+
},
104+
// Regular tick: process events.
105+
_ = sleep(interval) => {
106+
collector.update_events().await;
107+
}
108+
}
109+
}
110+
})
111+
}
112+
113+
/// Spawns the S3 pusher task that listens for batches of events to push.
114+
fn start_s3_pusher(
115+
mut s3_rx: mpsc::Receiver<NetworkMsg>,
116+
) -> tokio::task::JoinHandle<()> {
117+
tokio::spawn(async move {
118+
while let Some(msg) = s3_rx.recv().await {
119+
match msg {
120+
NetworkMsg::SendEvents(events) => {
121+
// Call the async function to push events to S3.
122+
if let Err(e) = push_events_to_s3(events).await {
123+
log::error!("Failed to push events to S3: {:?}", e);
124+
}
125+
}
126+
NetworkMsg::Shutdown => {
127+
break;
128+
}
129+
}
130+
}
131+
log::info!("S3 pusher task shutting down.");
132+
})
133+
}
134+
135+
/// Processes new events from the tracer and cleans up finished function events.
136+
/// Also sends any gathered events to the S3 pusher task.
137+
async fn update_events(&self) {
138+
let events = {
139+
let tracer = BAML_TRACER.lock().unwrap();
140+
self.tracked_ids
141+
.iter_mut()
142+
.flat_map(|mut kv| {
143+
if let Some(events) = tracer.get_events(kv.key()) {
144+
// Get events beyond the last processed index.
145+
let last_event_index = *kv.value();
146+
let new_events = events
147+
.iter()
148+
.skip(last_event_index)
149+
.cloned()
150+
.collect::<Vec<_>>();
151+
*kv.value_mut() = new_events.len();
152+
new_events
153+
} else {
154+
vec![]
155+
}
156+
})
157+
.collect::<Vec<_>>()
158+
};
159+
160+
// Identify finished function events and untrack them.
161+
let finished_events: Vec<_> = events
162+
.iter()
163+
.filter(|e| matches!(e.content, TraceData::FunctionEnd(_)))
164+
.map(|e| &e.span_id)
165+
.collect();
166+
for fid in finished_events {
167+
self.untrack_function(fid);
168+
}
169+
170+
// If there are events to push, send them to the S3 pusher task.
171+
if !events.is_empty() {
172+
if let Err(e) = self.s3_tx.send(NetworkMsg::SendEvents(events)).await {
173+
log::error!("Failed to send events to S3 pusher: {:?}", e);
174+
}
175+
}
176+
}
177+
178+
/// Initiates a graceful shutdown of both the collector and S3 pusher tasks.
179+
/// Sends a shutdown signal and awaits task completion with a timeout.
180+
pub async fn shutdown(&self, timeout_duration: Duration) {
181+
// Send the shutdown signal to the main collector task.
182+
if let Err(e) = self.shutdown_tx.send(CollectorMsg::Shutdown).await {
183+
log::error!("Failed to send shutdown signal: {:?}", e);
184+
}
185+
// Wait for the main collector task to finish.
186+
if let Some(handle) = { let mut guard = self.join_handle.lock().await; guard.take() } {
187+
match timeout(timeout_duration, handle).await {
188+
Ok(result) => {
189+
if let Err(e) = result {
190+
log::error!("Collector task error: {:?}", e);
191+
}
192+
}
193+
Err(_) => {
194+
log::warn!("Timeout while waiting for the collector task to shut down.");
195+
}
196+
}
197+
}
198+
199+
// Signal the S3 pusher to shut down by closing its channel.
200+
self.s3_tx.send(NetworkMsg::Shutdown).await.unwrap();
201+
// Wait for the S3 pusher task to finish.
202+
if let Some(s3_handle) = { let mut guard = self.s3_join_handle.lock().await; guard.take() } {
203+
match timeout(timeout_duration, s3_handle).await {
204+
Ok(result) => {
205+
if let Err(e) = result {
206+
log::error!("S3 pusher task error: {:?}", e);
207+
}
208+
}
209+
Err(_) => {
210+
log::warn!("Timeout while waiting for the S3 pusher task to shut down.");
211+
}
212+
}
213+
}
214+
215+
if !self.tracked_ids.is_empty() {
216+
log::warn!("Some functions are still being tracked and will be dropped/canceled.");
217+
}
218+
}
219+
}
220+
221+
/// A placeholder async function simulating pushing events to S3.
222+
/// Replace this with your actual S3 upload logic.
223+
async fn push_events_to_s3(events: Vec<Arc<TraceEvent>>) -> Result<(), Box<dyn std::error::Error>> {
224+
log::info!("Pushing {} events to S3", events.len());
225+
// Simulate network delay.
226+
sleep(Duration::from_millis(100)).await;
227+
// TODO: implement real S3 push logic here.
228+
Ok(())
229+
}

0 commit comments

Comments
 (0)