Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions cli/planoai/config_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,15 +562,15 @@ def validate_and_render_schema():
"Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers"
)

# Validate input_filters IDs on listeners reference valid agent/filter IDs
# Validate listener-level filter IDs reference valid agent/filter IDs.
for listener in listeners:
listener_input_filters = listener.get("input_filters", [])
for fc_id in listener_input_filters:
if fc_id not in agent_id_keys:
raise Exception(
f"Listener '{listener.get('name', 'unknown')}' references input_filters id '{fc_id}' "
f"which is not defined in agents or filters. Available ids: {', '.join(sorted(agent_id_keys))}"
)
for filter_field in ("input_filters", "output_filters"):
for fc_id in listener.get(filter_field, []):
if fc_id not in agent_id_keys:
raise Exception(
f"Listener '{listener.get('name', 'unknown')}' references {filter_field} id '{fc_id}' "
f"which is not defined in agents or filters. Available ids: {', '.join(sorted(agent_id_keys))}"
)

# Validate model aliases if present
if "model_aliases" in config_yaml:
Expand Down
57 changes: 57 additions & 0 deletions cli/test/test_config_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,63 @@ def test_validate_and_render_happy_path_agent_config(monkeypatch):
tracing:
random_sampling: 100

""",
},
{
"id": "unknown_listener_output_filter",
"expected_error": "references output_filters id 'missing_output_guard'",
"plano_config": """
version: v0.4.0

filters:
- id: input_guard
url: http://localhost:10500
type: http

listeners:
- name: llm
type: model
port: 12000
input_filters:
- input_guard
output_filters:
- missing_output_guard

model_providers:
- model: openai/gpt-4o-mini
access_key: $OPENAI_API_KEY
default: true

""",
},
{
"id": "valid_listener_output_filter",
"expected_error": None,
"plano_config": """
version: v0.4.0

filters:
- id: input_guard
url: http://localhost:10500
type: http
- id: output_guard
url: http://localhost:10501
type: http

listeners:
- name: llm
type: model
port: 12000
input_filters:
- input_guard
output_filters:
- output_guard

model_providers:
- model: openai/gpt-4o-mini
access_key: $OPENAI_API_KEY
default: true

""",
},
]
Expand Down
113 changes: 95 additions & 18 deletions crates/brightstaff/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,25 +142,19 @@ async fn init_app_state(
.listeners
.iter()
.find(|l| l.listener_type == ListenerType::Model);
let resolve_chain = |filter_ids: Option<Vec<String>>| -> Option<ResolvedFilterChain> {
filter_ids.map(|ids| {
let agents = ids
.iter()
.filter_map(|id| {
global_agent_map
.get(id)
.map(|a: &Agent| (id.clone(), a.clone()))
})
.collect();
ResolvedFilterChain {
filter_ids: ids,
agents,
}
})
};
let filter_pipeline = Arc::new(FilterPipeline {
input: resolve_chain(model_listener.and_then(|l| l.input_filters.clone())),
output: resolve_chain(model_listener.and_then(|l| l.output_filters.clone())),
input: resolve_filter_chain(
"input_filters",
model_listener.and_then(|l| l.input_filters.clone()),
&global_agent_map,
)
.map_err(|e| format!("failed to resolve model listener input filters: {e}"))?,
output: resolve_filter_chain(
"output_filters",
model_listener.and_then(|l| l.output_filters.clone()),
&global_agent_map,
)
.map_err(|e| format!("failed to resolve model listener output filters: {e}"))?,
});

let overrides = config.overrides.clone().unwrap_or_default();
Expand Down Expand Up @@ -350,6 +344,29 @@ async fn init_app_state(
})
}

fn resolve_filter_chain(
field_name: &str,
filter_ids: Option<Vec<String>>,
global_agent_map: &HashMap<String, Agent>,
) -> Result<Option<ResolvedFilterChain>, String> {
let Some(ids) = filter_ids else {
return Ok(None);
};

let mut agents = HashMap::new();
for id in &ids {
let agent = global_agent_map
.get(id)
.ok_or_else(|| format!("{field_name} id '{id}' is not defined in agents or filters"))?;
agents.insert(id.clone(), agent.clone());
}

Ok(Some(ResolvedFilterChain {
filter_ids: ids,
agents,
}))
}

/// Initialize the conversation state storage backend (if configured).
async fn init_state_storage(
config: &Configuration,
Expand Down Expand Up @@ -588,3 +605,63 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let state = Arc::new(init_app_state(&config).await?);
run_server(state).await
}

#[cfg(test)]
mod tests {
use super::*;

fn test_agent(id: &str) -> Agent {
Agent {
id: id.to_string(),
transport: None,
tool: None,
url: "http://localhost:10500".to_string(),
agent_type: Some("http".to_string()),
}
}

#[test]
fn resolve_filter_chain_keeps_valid_filter_references() {
let agent = test_agent("output_guard");
let global_agent_map = HashMap::from([(agent.id.clone(), agent)]);

let resolved = resolve_filter_chain(
"output_filters",
Some(vec!["output_guard".to_string()]),
&global_agent_map,
)
.expect("filter chain should resolve")
.expect("filter chain should be present");

assert_eq!(resolved.filter_ids, vec!["output_guard".to_string()]);
assert!(resolved.agents.contains_key("output_guard"));
}

#[test]
fn resolve_filter_chain_errors_on_missing_output_filter_reference() {
let global_agent_map = HashMap::new();

let err = resolve_filter_chain(
"output_filters",
Some(vec!["missing_output_guard".to_string()]),
&global_agent_map,
)
.expect_err("missing output filter should fail closed");

assert!(err.contains("output_filters id 'missing_output_guard'"));
}

#[test]
fn resolve_filter_chain_errors_on_missing_input_filter_reference() {
let global_agent_map = HashMap::new();

let err = resolve_filter_chain(
"input_filters",
Some(vec!["missing_input_guard".to_string()]),
&global_agent_map,
)
.expect_err("missing input filter should fail closed");

assert!(err.contains("input_filters id 'missing_input_guard'"));
}
}
Loading