diff --git a/.gitignore b/.gitignore index c89af90d7..82ded808b 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ meta-llama/** **/debug/**/*.json requirements-freeze*.txt /playground +.env # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/agentlightning/algorithm/apo/apo.py b/agentlightning/algorithm/apo/apo.py index 894ea8be2..8615da8bd 100644 --- a/agentlightning/algorithm/apo/apo.py +++ b/agentlightning/algorithm/apo/apo.py @@ -35,6 +35,11 @@ import poml from openai import AsyncOpenAI +try: + from litellm import acompletion as litellm_acompletion +except ImportError: + litellm_acompletion = None + from agentlightning.adapter.messages import TraceToMessages from agentlightning.algorithm.base import Algorithm from agentlightning.algorithm.utils import batch_iter_over_dataset, with_llm_proxy, with_store @@ -100,8 +105,9 @@ class APO(Algorithm, Generic[T_task]): def __init__( self, - async_openai_client: AsyncOpenAI, + async_openai_client: Optional[AsyncOpenAI] = None, *, + use_litellm: bool = False, gradient_model: str = "gpt-5-mini", apply_edit_model: str = "gpt-4.1-mini", diversity_temperature: float = 1.0, @@ -122,6 +128,8 @@ def __init__( Args: async_openai_client: AsyncOpenAI client for making LLM API calls. + Optional when ``use_litellm=True``. + use_litellm: If True, uses ``litellm.acompletion`` directly for LLM calls. gradient_model: Model name for computing textual gradients (critiques). apply_edit_model: Model name for applying edits based on critiques. diversity_temperature: Temperature parameter for LLM calls to control diversity. @@ -137,7 +145,11 @@ def __init__( gradient_prompt_files: Prompt templates used to compute textual gradients (critiques). apply_edit_prompt_files: Prompt templates used to apply edits based on critiques. 
async def _chat_completion_create(
    self,
    *,
    model: str,
    messages: Any,
    temperature: float,
) -> Any:
    """Send one chat-completion request through the configured backend.

    When ``self.use_litellm`` is true the request goes to the module-level
    ``litellm_acompletion``; otherwise it goes to
    ``self.async_openai_client.chat.completions.create``. Both backends are
    called with the same three keyword arguments.

    Args:
        model: Model name forwarded unchanged to the backend.
        messages: Chat messages forwarded unchanged to the backend.
        temperature: Sampling temperature forwarded unchanged to the backend.

    Returns:
        Whatever the selected backend returns for the completion call.

    Raises:
        ImportError: If ``use_litellm`` is true but litellm could not be imported.
        ValueError: If ``use_litellm`` is false and no OpenAI client was provided.
    """
    if self.use_litellm:
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would turn this into an opaque
        # "'NoneType' object is not callable" error.
        if litellm_acompletion is None:
            raise ImportError("litellm is not installed but use_litellm=True was provided.")
        return await litellm_acompletion(
            model=model,
            messages=messages,
            temperature=temperature,
        )

    if self.async_openai_client is None:
        raise ValueError("async_openai_client must be provided when use_litellm=False")

    return await self.async_openai_client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
    )
`room_selector.py` | Room booking agent implementation using function calling | | `room_selector_apo.py` | Training script using the built-in APO algorithm to optimize prompts | | `room_tasks.jsonl` | Dataset with room booking scenarios and expected selections | -| `apo_custom_algorithm.py` | Tutorial on creating custom algorithms (runnable as algo or runner) | +| `apo_custom_algorithm.py` | Tutorial on creating custom algorithms | | `apo_custom_algorithm_trainer.py` | Shows how to integrate custom algorithms into the Trainer | | `apo_debug.py` | Tutorial demonstrating various agent debugging techniques | -| `legacy_apo_client.py` | Deprecated APO client implementation compatible with Agent-lightning v0.1.x | -| `legacy_apo_server.py` | Deprecated APO server implementation compatible with Agent-lightning v0.1.x | +| `apo_multi_provider.py` | Hybrid optimization sample using multiple LLM backends | +| `legacy_apo_client.py` | Deprecated APO client implementation compatible with v0.1.x | +| `legacy_apo_server.py` | Deprecated APO server implementation compatible with v0.1.x | ## Sample 1: Using Built-in APO Algorithm diff --git a/examples/apo/apo_custom_algorithm.py b/examples/apo/apo_custom_algorithm.py index 45ba529ec..c92ae42eb 100644 --- a/examples/apo/apo_custom_algorithm.py +++ b/examples/apo/apo_custom_algorithm.py @@ -89,7 +89,7 @@ async def apo_algorithm(*, store: agl.LightningStore): await log_llm_span(spans) # 4. The algorithm records the final reward for sorting - final_reward = agl.find_final_reward(spans) + final_reward = agl.find_final_reward(spans) # type: ignore assert final_reward is not None, "Expected a final reward from the client." 
def load_math_tasks() -> List[Dict[str, str]]:
    """Return the four hard-coded question/expected-answer records of the mock dataset."""
    question_answer_pairs = (
        ("If I have 3 apples and buy 2 more, how many do I have?", "5"),
        ("A train travels 60 miles in 1 hour. How far in 3 hours?", "180"),
        ("What is the square root of 144?", "12"),
        ("If a shirt costs $20 and is 10% off, what is the price?", "18"),
    )
    return [
        {"question": question_text, "expected": expected_text}
        for question_text, expected_text in question_answer_pairs
    ]


def load_train_val_dataset() -> Tuple[Dataset[Dict[str, str]], Dataset[Dict[str, str]]]:
    """Split the mock tasks at the midpoint into a (train, validation) pair.

    The first ``len(tasks) // 2`` records form the training set and the
    remaining records form the validation set.
    """
    tasks = load_math_tasks()
    midpoint = len(tasks) // 2
    # Slicing already yields fresh lists; ``cast`` only informs the type checker.
    first_half = tasks[:midpoint]
    second_half = tasks[midpoint:]
    return (
        cast(Dataset[Dict[str, str]], first_half),
        cast(Dataset[Dict[str, str]], second_half),
    )
class MathAgent(agl.LitAgent[Dict[str, str]]):
    """Agent that answers one math question per rollout and scores the reply.

    Args:
        model: LiteLLM model identifier used to answer the question. Defaults
            to the model this sample originally hard-coded.
    """

    # Signed integer or decimal. The look-behind stops the "-" in "10-18"
    # (or the second dot in "1.2.3") from starting a new number.
    _NUMBER_RE = re.compile(r"(?<![\d.])[-+]?(?:\d+(?:\.\d+)?|\.\d+)")

    def __init__(self, model: str = "gemini/gemini-2.0-flash") -> None:
        super().__init__()
        self.answer_model = model

    @classmethod
    def _score(cls, answer: str, expected: str) -> float:
        """Return 1.0 when the last number after "Answer:" equals ``expected``, else 0.0.

        Numbers are compared by value so that "18.00" matches "18" and a
        leading sign is honoured. If either side is not parseable as a float
        the comparison falls back to exact string equality.
        """
        candidates = cls._NUMBER_RE.findall(answer.split("Answer:")[-1])
        if not candidates:
            return 0.0
        predicted = candidates[-1]
        try:
            matched = abs(float(predicted) - float(expected)) <= 1e-9
        except ValueError:
            matched = predicted == expected
        return 1.0 if matched else 0.0

    def rollout(self, task: Any, resources: Dict[str, Any], rollout: Any) -> float:
        """Answer ``task["question"]`` with the current prompt template and return the reward.

        Args:
            task: Mapping with ``"question"`` and ``"expected"`` string entries.
            resources: Mapping that must contain a ``"prompt_template"`` entry.
            rollout: Rollout metadata; not used by this agent.

        Returns:
            1.0 if the model's final number equals the expected answer, else 0.0.

        Raises:
            ValueError: If ``resources`` has no ``"prompt_template"``.
        """
        t = cast(Dict[str, str], task)

        prompt_template = resources.get("prompt_template")
        if prompt_template is None:
            # Without this check the literal string "None" would be sent as the prompt.
            raise ValueError("resources must contain a 'prompt_template' entry")

        # Accept either an object exposing ``.template`` or a plain string-like value.
        template_str = getattr(prompt_template, "template", str(prompt_template))
        prompt = template_str.format(question=t["question"])

        # Direct LiteLLM call
        response = completion(
            model=self.answer_model,
            messages=[{"role": "user", "content": prompt}],
        )
        content = response.choices[0].message.content
        answer = content if isinstance(content, str) else ""

        reward = self._score(answer, t["expected"])

        # NOTE(review): the reward is both emitted and returned; confirm the
        # runner does not record it twice.
        agl.emit_reward(reward)
        return reward
def setup_apo_logger(file_path: str = "apo_math.log") -> None:
    """Attach an INFO-level file handler to the ``agentlightning.algorithm.apo`` logger.

    Calling this again for the same file is a no-op, so repeated setup does
    not produce duplicated log lines.

    Args:
        file_path: Path of the log file to write to.
    """
    import os  # local import: this module does not import ``os`` at file level

    apo_logger = logging.getLogger("agentlightning.algorithm.apo")
    target = os.path.abspath(file_path)
    for existing in apo_logger.handlers:
        # ``FileHandler.baseFilename`` holds the absolute path it was opened with.
        if isinstance(existing, logging.FileHandler) and existing.baseFilename == target:
            return

    file_handler = logging.FileHandler(file_path)
    file_handler.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s [%(levelname)s] (%(name)s) %(message)s")
    file_handler.setFormatter(formatter)
    apo_logger.addHandler(file_handler)


def main() -> None:
    """Run one small APO round on the mock math tasks and print the resulting prompt.

    Builds an ``APO`` algorithm in LiteLLM mode with separate gradient and
    apply-edit models, fits a ``Trainer`` with two runners on the train and
    validation splits, then prints the initial prompt and, if the store
    returns any resources, the template of the last one.
    """
    setup_logging()
    setup_apo_logger()

    initial_prompt_str = "Solve: {question}"

    algo = APO[Dict[str, str]](
        use_litellm=True,
        gradient_model="gemini/gemini-2.0-flash",
        apply_edit_model="groq/llama-3.3-70b-versatile",
        val_batch_size=2,
        gradient_batch_size=2,
        beam_width=1,
        branch_factor=1,
        beam_rounds=1,
    )

    trainer = Trainer(
        algorithm=algo,
        n_runners=2,
        initial_resources={
            "prompt_template": PromptTemplate(template=initial_prompt_str, engine="f-string")
        },
        adapter=TraceToMessages(),
    )

    dataset_train, dataset_val = load_train_val_dataset()
    agent = MathAgent()

    print("\n" + "="*60)
    print("🚀 HYBRID APO OPTIMIZATION STARTING")
    print("-" * 60)

    trainer.fit(agent=agent, train_dataset=dataset_train, val_dataset=dataset_val)

    # Print Final Prompt from the store
    print("\n" + "="*60)
    print("✅ OPTIMIZATION COMPLETE")
    print("-" * 60)
    print(f"INITIAL PROMPT:\n{initial_prompt_str}")

    # Best-effort read-back of the latest prompt: a failure here must not hide
    # the fact that training itself already finished, so it is reported and skipped.
    try:
        latest_resources = asyncio.run(trainer.store.query_resources())
        if latest_resources:
            final_res = latest_resources[-1].resources.get("prompt_template")
            final_prompt = getattr(final_res, "template", str(final_res))
            print(f"FINAL OPTIMIZED PROMPT:\n{final_prompt}")
    except Exception as e:
        print(f"Optimization finished. Check apo_math.log for detailed iteration results. Error: {e}")

    print("="*60 + "\n")


if __name__ == "__main__":
    try:
        multiprocessing.set_start_method("fork", force=True)
    except (RuntimeError, ValueError):
        # ValueError is what platforms without "fork" (e.g. Windows) raise;
        # with force=True a RuntimeError is not expected but is still tolerated.
        pass
    main()