Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ LANGSMITH_API_KEY=
LANGSMITH_PROJECT=
LANGSMITH_TRACING=

GROQ_API_KEY=
GENSEE_API_KEY=

# Only necessary for Open Agent Platform
Expand Down
30 changes: 29 additions & 1 deletion src/open_deep_research/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
from enum import Enum
from typing import Any, List, Optional
from typing import Any, List, Optional, Literal

from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel, Field
Expand Down Expand Up @@ -75,6 +75,34 @@ class Configuration(BaseModel):
}
}
)
# Parallel Supervisors Configuration
max_concurrent_supervisors: int = Field(
default=3,
metadata={
"x_oap_ui_config": {
"type": "slider",
"default": 3,
"min": 1,
"max": 10,
"step": 1,
"description": "Maximum number of supervisor subgraphs to run in parallel after the research brief."
}
}
)
parallel_supervisor_strategy: Literal["early_stop", "aggregate"] = Field(
default="early_stop",
metadata={
"x_oap_ui_config": {
"type": "select",
"default": "early_stop",
"description": "Whether parallel supervisors receive the identical brief or variants of it.",
"options": [
{"label": "Early Stop", "value": "early_stop"},
{"label": "Aggregate", "value": "aggregate"}
]
}
}
)
# Research Configuration
search_api: SearchAPI = Field(
default=SearchAPI.TAVILY,
Expand Down
111 changes: 103 additions & 8 deletions src/open_deep_research/deep_researcher.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Main LangGraph implementation for the Deep Research agent."""

import asyncio
import json
from typing import Literal

from langchain.chat_models import init_chat_model
Expand Down Expand Up @@ -86,19 +87,29 @@ async def clarify_with_user(state: AgentState, config: RunnableConfig) -> Comman
}

# Configure model with structured output and retry logic
clarification_model = (
configurable_model
.with_structured_output(ClarifyWithUser)
.with_retry(stop_after_attempt=configurable.max_structured_output_retries)
.with_config(model_config)
)
if "groq" in configurable.research_model:
clarification_model = (
configurable_model
.with_retry(stop_after_attempt=configurable.max_structured_output_retries)
.with_config(model_config)
)
else:
clarification_model = (
configurable_model
.with_structured_output(ClarifyWithUser)
.with_retry(stop_after_attempt=configurable.max_structured_output_retries)
.with_config(model_config)
)

# Step 3: Analyze whether clarification is needed
prompt_content = clarify_with_user_instructions.format(
messages=get_buffer_string(messages),
date=get_today_str()
)
response = await clarification_model.ainvoke([HumanMessage(content=prompt_content)])
if "groq" in configurable.research_model:
response = json.loads(response.content)
response = ClarifyWithUser(**response)

# Step 4: Route based on clarification analysis
if response.need_clarification:
Expand Down Expand Up @@ -145,14 +156,29 @@ async def write_research_brief(state: AgentState, config: RunnableConfig) -> Com
.with_retry(stop_after_attempt=configurable.max_structured_output_retries)
.with_config(research_model_config)
)
if "groq" in configurable.research_model:
research_model = (
configurable_model
.with_retry(stop_after_attempt=configurable.max_structured_output_retries)
.with_config(research_model_config)
)
else:
research_model = (
configurable_model
.with_structured_output(ResearchQuestion)
.with_retry(stop_after_attempt=configurable.max_structured_output_retries)
.with_config(research_model_config)
)

# Step 2: Generate structured research brief from user messages
prompt_content = transform_messages_into_research_topic_prompt.format(
messages=get_buffer_string(state.get("messages", [])),
date=get_today_str()
)
response = await research_model.ainvoke([HumanMessage(content=prompt_content)])

if "groq" in configurable.research_model:
response = ResearchQuestion(research_brief=response.content)

# Step 3: Initialize supervisor with research brief and instructions
supervisor_system_prompt = lead_researcher_prompt.format(
date=get_today_str(),
Expand Down Expand Up @@ -362,6 +388,75 @@ async def supervisor_tools(state: SupervisorState, config: RunnableConfig) -> Co
# Compile supervisor subgraph for use in main workflow
supervisor_subgraph = supervisor_builder.compile()

# -----------------------------
# Parallel Supervisors Orchestrator
# -----------------------------
async def multiple_supervisors(state: AgentState, config: RunnableConfig) -> Command[Literal["final_report_generation"]]:
    """Spawn multiple supervisor subgraphs in parallel, then early-stop or aggregate.

    Every supervisor receives the identical research brief. Behavior depends on
    ``parallel_supervisor_strategy``:

    * ``"early_stop"``: the first supervisor to finish wins; the remaining
      tasks are cancelled and only the winner's notes are forwarded.
    * ``"aggregate"``: all supervisors run to completion and their notes are
      merged into one flat list.

    Args:
        state: Current agent state; ``research_brief`` is read from it.
        config: Runtime configuration carrying the ``Configuration`` values.

    Returns:
        Command routing to ``final_report_generation`` with the collected
        ``notes`` and the (unchanged) ``research_brief``.
    """
    configurable = Configuration.from_runnable_config(config)
    research_brief = state.get("research_brief", "")

    # Determine number of supervisors to launch (apply a hard cap for safety)
    hard_cap = 10
    num_supervisors = max(1, min(configurable.max_concurrent_supervisors, hard_cap))

    briefs_for_supervisors = [research_brief for _ in range(num_supervisors)]

    # Prepare supervisor system prompt once (identical for every supervisor)
    supervisor_system_prompt = lead_researcher_prompt.format(
        date=get_today_str(),
        max_concurrent_research_units=configurable.max_concurrent_research_units,
        max_researcher_iterations=configurable.max_researcher_iterations,
    )

    # Launch all supervisors in parallel
    tasks = [
        asyncio.create_task(
            supervisor_subgraph.ainvoke(
                {
                    "supervisor_messages": {
                        "type": "override",
                        "value": [
                            SystemMessage(content=supervisor_system_prompt),
                            HumanMessage(content=brief),
                        ],
                    },
                    "research_brief": research_brief,
                },
                config,
            )
        )
        for brief in briefs_for_supervisors
    ]

    if configurable.parallel_supervisor_strategy == "early_stop":
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        winner_result = await next(iter(done))
        # Cancel remaining tasks
        for t in pending:
            t.cancel()
        # Drain cancellations to avoid "task was destroyed" warnings
        await asyncio.gather(*pending, return_exceptions=True)
        # BUG FIX: the original never assigned final_result on this (default)
        # path, so the return below raised NameError. Forward the winner's notes.
        final_result = {"notes": winner_result.get("notes", [])}
    else:  # "aggregate" (Literal type admits only the two values)
        results = await asyncio.gather(*tasks)
        # BUG FIX: flatten per-supervisor note lists instead of nesting them,
        # so downstream report generation receives a flat list of notes.
        final_result = {
            "notes": [note for result in results for note in result.get("notes", [])],
        }

    return Command(
        goto="final_report_generation",
        update={
            "notes": final_result.get("notes", []),
            "research_brief": research_brief,
        },
    )


async def researcher(state: ResearcherState, config: RunnableConfig) -> Command[Literal["researcher_tools"]]:
"""Individual researcher that conducts focused research on specific topics.

Expand Down Expand Up @@ -707,7 +802,7 @@ async def final_report_generation(state: AgentState, config: RunnableConfig):
# Add main workflow nodes for the complete research process
deep_researcher_builder.add_node("clarify_with_user", clarify_with_user) # User clarification phase
deep_researcher_builder.add_node("write_research_brief", write_research_brief) # Research planning phase
deep_researcher_builder.add_node("research_supervisor", supervisor_subgraph) # Research execution phase
deep_researcher_builder.add_node("research_supervisor", multiple_supervisors) # Research execution phase
deep_researcher_builder.add_node("final_report_generation", final_report_generation) # Report generation phase

# Define main workflow edges for sequential execution
Expand Down
4 changes: 4 additions & 0 deletions src/open_deep_research/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,6 +1032,8 @@ def get_api_key_for_model(model_name: str, config: RunnableConfig):
return api_keys.get("ANTHROPIC_API_KEY")
elif model_name.startswith("google"):
return api_keys.get("GOOGLE_API_KEY")
elif model_name.startswith("groq"):
return api_keys.get("GROQ_API_KEY")
return None
else:
if model_name.startswith("openai:"):
Expand All @@ -1040,6 +1042,8 @@ def get_api_key_for_model(model_name: str, config: RunnableConfig):
return os.getenv("ANTHROPIC_API_KEY")
elif model_name.startswith("google"):
return os.getenv("GOOGLE_API_KEY")
elif model_name.startswith("groq"):
return os.getenv("GROQ_API_KEY")
return None

def get_tavily_api_key(config: RunnableConfig):
Expand Down