Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ meta-llama/**
**/debug/**/*.json
requirements-freeze*.txt
/playground
.env

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
42 changes: 39 additions & 3 deletions agentlightning/algorithm/apo/apo.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@
import poml
from openai import AsyncOpenAI

try:
from litellm import acompletion as litellm_acompletion
except ImportError:
litellm_acompletion = None

from agentlightning.adapter.messages import TraceToMessages
from agentlightning.algorithm.base import Algorithm
from agentlightning.algorithm.utils import batch_iter_over_dataset, with_llm_proxy, with_store
Expand Down Expand Up @@ -100,8 +105,9 @@ class APO(Algorithm, Generic[T_task]):

def __init__(
self,
async_openai_client: AsyncOpenAI,
async_openai_client: Optional[AsyncOpenAI] = None,
*,
use_litellm: bool = False,
gradient_model: str = "gpt-5-mini",
apply_edit_model: str = "gpt-4.1-mini",
diversity_temperature: float = 1.0,
Expand All @@ -122,6 +128,8 @@ def __init__(

Args:
async_openai_client: AsyncOpenAI client for making LLM API calls.
Optional when ``use_litellm=True``.
use_litellm: If True, uses ``litellm.acompletion`` directly for LLM calls.
gradient_model: Model name for computing textual gradients (critiques).
apply_edit_model: Model name for applying edits based on critiques.
diversity_temperature: Temperature parameter for LLM calls to control diversity.
Expand All @@ -137,7 +145,11 @@ def __init__(
gradient_prompt_files: Prompt templates used to compute textual gradients (critiques).
apply_edit_prompt_files: Prompt templates used to apply edits based on critiques.
"""
if use_litellm and litellm_acompletion is None:
raise ImportError("litellm is not installed but use_litellm=True was provided.")

self.async_openai_client = async_openai_client
self.use_litellm = use_litellm
self.gradient_model = gradient_model
self.apply_edit_model = apply_edit_model
self.diversity_temperature = diversity_temperature
Expand All @@ -159,6 +171,30 @@ def __init__(

self._poml_trace = _poml_trace

async def _chat_completion_create(
    self,
    *,
    model: str,
    messages: Any,
    temperature: float,
) -> Any:
    """Dispatch a chat-completion request to the configured LLM backend.

    Routes through ``litellm.acompletion`` when ``use_litellm`` is enabled,
    otherwise through the injected AsyncOpenAI client.

    Args:
        model: Model identifier to send to the backend.
        messages: Chat messages payload forwarded unchanged.
        temperature: Sampling temperature forwarded unchanged.

    Returns:
        The backend's completion response object.

    Raises:
        ValueError: If no AsyncOpenAI client was provided and litellm mode
            is disabled.
    """
    request_kwargs = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
    }
    if self.use_litellm:
        # The constructor already raised ImportError if litellm is missing,
        # so this assert only narrows the type for checkers.
        assert litellm_acompletion is not None
        return await litellm_acompletion(**request_kwargs)
    if self.async_openai_client is None:
        raise ValueError("async_openai_client must be provided when use_litellm=False")
    return await self.async_openai_client.chat.completions.create(**request_kwargs)

def _create_versioned_prompt(
self,
prompt_template: PromptTemplate,
Expand Down Expand Up @@ -307,7 +343,7 @@ async def compute_textual_gradient(
f"Gradient computed with {self.gradient_model} prompt: {tg_msg}",
prefix=prefix,
)
critique_response = await self.async_openai_client.chat.completions.create(
critique_response = await self._chat_completion_create(
model=self.gradient_model,
messages=tg_msg["messages"], # type: ignore
temperature=self.diversity_temperature,
Expand Down Expand Up @@ -373,7 +409,7 @@ async def textual_gradient_and_apply_edit(
format="openai_chat",
)

ae_response = await self.async_openai_client.chat.completions.create(
ae_response = await self._chat_completion_create(
model=self.apply_edit_model,
messages=ae_msg["messages"], # type: ignore
temperature=self.diversity_temperature,
Expand Down
7 changes: 4 additions & 3 deletions examples/apo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ Follow the [installation guide](../../docs/tutorials/installation.md) to install
| `room_selector.py` | Room booking agent implementation using function calling |
| `room_selector_apo.py` | Training script using the built-in APO algorithm to optimize prompts |
| `room_tasks.jsonl` | Dataset with room booking scenarios and expected selections |
| `apo_custom_algorithm.py` | Tutorial on creating custom algorithms (runnable as algo or runner) |
| `apo_custom_algorithm.py` | Tutorial on creating custom algorithms |
| `apo_custom_algorithm_trainer.py` | Shows how to integrate custom algorithms into the Trainer |
| `apo_debug.py` | Tutorial demonstrating various agent debugging techniques |
| `legacy_apo_client.py` | Deprecated APO client implementation compatible with Agent-lightning v0.1.x |
| `legacy_apo_server.py` | Deprecated APO server implementation compatible with Agent-lightning v0.1.x |
| `apo_multi_provider.py` | Hybrid optimization sample using multiple LLM backends |
| `legacy_apo_client.py` | Deprecated APO client implementation compatible with v0.1.x |
| `legacy_apo_server.py` | Deprecated APO server implementation compatible with v0.1.x |

## Sample 1: Using Built-in APO Algorithm

Expand Down
2 changes: 1 addition & 1 deletion examples/apo/apo_custom_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ async def apo_algorithm(*, store: agl.LightningStore):
await log_llm_span(spans)

# 4. The algorithm records the final reward for sorting
final_reward = agl.find_final_reward(spans)
final_reward = agl.find_final_reward(spans) # type: ignore
assert final_reward is not None, "Expected a final reward from the client."
console.print(f"{algo_marker} Final reward: {final_reward}")
prompt_and_rewards.append((prompt, final_reward))
Expand Down
139 changes: 139 additions & 0 deletions examples/apo/apo_multi_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
"""
This sample code demonstrates how to use APO's built-in LiteLLM mode
to tune mathematical reasoning prompts using a hybrid model setup.
"""

import logging
import re
import asyncio
import multiprocessing
from typing import Tuple, cast, Dict, Any, List, Callable

from dotenv import load_dotenv
load_dotenv()
import agentlightning as agl
from agentlightning import Trainer, setup_logging, PromptTemplate
from agentlightning.adapter import TraceToMessages
from agentlightning.algorithm.apo import APO
from agentlightning.types import Dataset
import litellm

completion: Callable[..., Any] = cast(Callable[..., Any], getattr(litellm, "completion"))


# --- 1. Dataset Logic ---
def load_math_tasks() -> List[Dict[str, str]]:
    """Return a small mock GSM8k-style dataset of question/answer pairs."""
    qa_pairs = [
        ("If I have 3 apples and buy 2 more, how many do I have?", "5"),
        ("A train travels 60 miles in 1 hour. How far in 3 hours?", "180"),
        ("What is the square root of 144?", "12"),
        ("If a shirt costs $20 and is 10% off, what is the price?", "18"),
    ]
    return [{"question": question, "expected": answer} for question, answer in qa_pairs]

def load_train_val_dataset() -> Tuple[Dataset[Dict[str, str]], Dataset[Dict[str, str]]]:
    """Split the mock task set evenly into train and validation halves."""
    tasks = load_math_tasks()
    midpoint = len(tasks) // 2
    # list() plus cast keeps Pylance happy about SupportsIndex/slice checks
    # on the abstract Dataset type.
    train_half = cast(Dataset[Dict[str, str]], list(tasks[:midpoint]))
    val_half = cast(Dataset[Dict[str, str]], list(tasks[midpoint:]))
    return train_half, val_half

# --- 2. Agent Logic ---
class MathAgent(agl.LitAgent[Dict[str, str]]):
    """Agent that answers a math word problem with one direct LiteLLM call.

    The rollout formats the tuned prompt template with the task's question,
    queries the configured model, and emits a binary exact-match reward.
    """

    def __init__(self, model: str = "gemini/gemini-2.0-flash") -> None:
        """Initialize the agent.

        Args:
            model: LiteLLM model identifier used for rollouts. Previously
                hard-coded; parameterized (with the old value as default) so
                the agent can be reused with other backends.
        """
        super().__init__()
        self.model = model

    def rollout(self, task: Any, resources: Dict[str, Any], rollout: Any) -> float:
        """Run one rollout and return the exact-match reward (0.0 or 1.0)."""
        # Pylance fix: explicitly cast task to Dict.
        t = cast(Dict[str, str], task)
        prompt_template: PromptTemplate = resources.get("prompt_template")  # type: ignore

        # Tolerate either a PromptTemplate object or a plain string resource.
        template_str = getattr(prompt_template, "template", str(prompt_template))
        prompt = template_str.format(question=t["question"])

        # Direct LiteLLM call (bypasses any proxy).
        response = completion(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
        )
        content = response.choices[0].message.content
        answer = content if isinstance(content, str) else ""

        # Reward: 1.0 iff the last number after "Answer:" matches the
        # expected string exactly, else 0.0.
        pred_nums = re.findall(r"[-+]?\d*\.\d+|\d+", answer.split("Answer:")[-1])
        reward = 1.0 if pred_nums and pred_nums[-1] == t["expected"] else 0.0

        agl.emit_reward(reward)
        return reward

# --- 3. Logging & Main ---
def setup_apo_logger(file_path: str = "apo_math.log") -> None:
    """Attach an INFO-level file handler to the APO algorithm logger.

    Args:
        file_path: Destination log file for APO iteration details.
    """
    handler = logging.FileHandler(file_path)
    handler.setLevel(logging.INFO)
    handler.setFormatter(
        logging.Formatter("%(asctime)s [%(levelname)s] (%(name)s) %(message)s")
    )
    logging.getLogger("agentlightning.algorithm.apo").addHandler(handler)

def main() -> None:
    """Run hybrid APO prompt optimization over the mock math dataset.

    Configures logging, builds an APO algorithm that uses two different
    LiteLLM backends (Gemini for gradients, Groq for edits), trains the
    MathAgent, and prints the initial and final optimized prompts.
    """
    setup_logging()
    setup_apo_logger()

    initial_prompt_str = "Solve: {question}"

    # Gradient and edit steps deliberately use different providers to
    # demonstrate LiteLLM's multi-backend routing.
    algo = APO[Dict[str, str]](
        use_litellm=True,
        gradient_model="gemini/gemini-2.0-flash",
        apply_edit_model="groq/llama-3.3-70b-versatile",
        val_batch_size=2,
        gradient_batch_size=2,
        beam_width=1,
        branch_factor=1,
        beam_rounds=1,
    )

    trainer = Trainer(
        algorithm=algo,
        n_runners=2,
        initial_resources={
            "prompt_template": PromptTemplate(template=initial_prompt_str, engine="f-string")
        },
        adapter=TraceToMessages(),
    )

    dataset_train, dataset_val = load_train_val_dataset()
    agent = MathAgent()

    # Fixed mojibake in banners: the original bytes "πŸš€"/"βœ…" were
    # mis-encoded forms of the intended emojis.
    print("\n" + "=" * 60)
    print("🚀 HYBRID APO OPTIMIZATION STARTING")
    print("-" * 60)

    trainer.fit(agent=agent, train_dataset=dataset_train, val_dataset=dataset_val)

    # Print final prompt from the store.
    print("\n" + "=" * 60)
    print("✅ OPTIMIZATION COMPLETE")
    print("-" * 60)
    print(f"INITIAL PROMPT:\n{initial_prompt_str}")

    # Access the latest optimized prompt from the trainer store; best-effort
    # because the store API may differ across agentlightning versions.
    try:
        latest_resources = asyncio.run(trainer.store.query_resources())
        if latest_resources:
            final_res = latest_resources[-1].resources.get("prompt_template")
            final_prompt = getattr(final_res, "template", str(final_res))
            print(f"FINAL OPTIMIZED PROMPT:\n{final_prompt}")
    except Exception as e:
        print(f"Optimization finished. Check apo_math.log for detailed iteration results. Error: {e}")

    print("=" * 60 + "\n")

if __name__ == "__main__":
    try:
        # Force "fork" so runner subprocesses inherit loaded modules and the
        # dotenv-populated environment; force=True overrides any start method
        # set earlier in the process.
        # NOTE(review): on platforms without "fork" (e.g. Windows) this raises
        # ValueError, which the except below does not catch — confirm intended.
        multiprocessing.set_start_method("fork", force=True)
    except RuntimeError:
        # Start method was already fixed elsewhere; keep whatever is set.
        pass
    main()
20 changes: 20 additions & 0 deletions tests/algorithm/test_apo.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,26 @@ async def test_compute_textual_gradient_uses_all_rollouts_when_insufficient(monk
assert result == "critique"


@pytest.mark.asyncio
async def test_compute_textual_gradient_uses_litellm_backend(monkeypatch: pytest.MonkeyPatch) -> None:
    """Verify APO routes gradient computation through litellm when enabled."""
    acompletion_mock = AsyncMock(return_value=make_completion("critique"))
    monkeypatch.setattr(apo_module, "litellm_acompletion", acompletion_mock)
    # Make the random rollout selection deterministic.
    monkeypatch.setattr(apo_module.random, "choice", lambda seq: seq[0])  # type: ignore

    apo = APO[Any](use_litellm=True, gradient_model="test-gradient-model", gradient_batch_size=1)
    versioned_prompt = apo._create_versioned_prompt(PromptTemplate(template="prompt", engine="f-string"))
    single_rollout: List[RolloutResultForAPO] = [
        RolloutResultForAPO(status="succeeded", final_reward=1.0, spans=[], messages=[])
    ]

    critique = await apo.compute_textual_gradient(versioned_prompt, single_rollout)

    assert critique == "critique"
    acompletion_mock.assert_awaited_once()
    awaited_kwargs = acompletion_mock.await_args.kwargs  # type: ignore
    assert awaited_kwargs["model"] == "test-gradient-model"


@pytest.mark.asyncio
async def test_textual_gradient_and_apply_edit_returns_new_prompt(monkeypatch: pytest.MonkeyPatch) -> None:
# Use two separate mocks for gradient and edit calls
Expand Down