From 26fd6bda8a903b2c0838938f99035df531d1e976 Mon Sep 17 00:00:00 2001 From: Ekin Bozyel Date: Sun, 1 Feb 2026 21:20:14 +0300 Subject: [PATCH 1/7] feat: add MultiProviderClient for multi-LLM provider routing --- .../utils/multi_provider_ai_client.py | 135 ++++++++++++++++++ test.py | 8 ++ 2 files changed, 143 insertions(+) create mode 100644 agentlightning/utils/multi_provider_ai_client.py create mode 100644 test.py diff --git a/agentlightning/utils/multi_provider_ai_client.py b/agentlightning/utils/multi_provider_ai_client.py new file mode 100644 index 000000000..358eb2025 --- /dev/null +++ b/agentlightning/utils/multi_provider_ai_client.py @@ -0,0 +1,135 @@ +""" +Multi-Provider Client +======================== +Async client that routes to different LLM providers based on model name. + +Usage: + - google-gemini-2.0-flash → Google API + - groq-llama-3.3-70b → Groq API + +Model name must be in "provider-model" format. + +## Usage +```python +from multi_provider_client import MultiProviderClient +client = MultiProviderClient() +# Use with APO +algo = agl.APO( + client, + gradient_model="google-gemini-2.0-flash", + apply_edit_model="groq-llama-3.3-70b-versatile", +) + + +""" + +import os +from openai import AsyncOpenAI + + +class MultiProviderClient: + """Async client that routes to different providers based on model name. + Model format: "provider-model_name" + Examples: + - google-gemini-2.0-flash + - groq-meta-llama/llama-4-maverick-17b-128e-instruct + - etc. + """ + + def __init__(self, custom_providers: dict[str, dict] | None = None): + """ + Args: + custom_providers: Additional providers. Format: + { + "provider_name": { + "api_key": "...", # or env var name + "base_url": "https://..." + } + } + """ + + self.clients = {} + + # Only create clients for providers with API keys + if os.getenv("GOOGLE_API_KEY"): + self.clients["google"] = AsyncOpenAI( + api_key=os.getenv("GOOGLE_API_KEY"), + base_url="https://generativelanguage.googleapis.com/v1beta/openai/") + + if os.getenv("GROQ_API_KEY"): + self.clients["groq"] = AsyncOpenAI( + api_key=os.getenv("GROQ_API_KEY"), + base_url="https://api.groq.com/openai/v1") + + if os.getenv("OPENAI_API_KEY"): + self.clients["openai"] = AsyncOpenAI( + api_key=os.getenv("OPENAI_API_KEY")) + + if os.getenv("AZURE_OPENAI_API_KEY") and os.getenv("AZURE_OPENAI_ENDPOINT"): + self.clients["azure"] = AsyncOpenAI( + api_key=os.getenv("AZURE_OPENAI_API_KEY"), + base_url=os.getenv("AZURE_OPENAI_ENDPOINT")) + + if os.getenv("OPENROUTER_API_KEY"): + self.clients["openrouter"] = AsyncOpenAI( + api_key=os.getenv("OPENROUTER_API_KEY"), + base_url="https://openrouter.ai/api/v1") + + # Add custom providers + if custom_providers: + for name, config in custom_providers.items(): + api_key = config.get("api_key") or os.getenv(config.get("api_key_env", "")) + base_url = config.get("base_url") + self.clients[name] = AsyncOpenAI(api_key=api_key, base_url=base_url) + + + def _parse_model(self, model: str) -> tuple[str, str]: + """Parse model name into provider and actual model name. + + Args: + model: String in "provider-model_name" format + + Returns: + (provider, actual_model_name) tuple + """ + if "-" not in model: + raise ValueError(f"Model format must be 'provider-model_name': {model}") + + idx = model.find("-") + provider = model[:idx] + actual_model = model[idx + 1:] + + if provider not in self.clients: + for name in self.clients: + if model.startswith(name + "-"): + provider = name + actual_model = model[len(name) + 1:] + break + else: + raise ValueError(f"Unknown provider: {provider}. Supported: {list(self.clients.keys())}") + + return provider, actual_model + + + @property + def chat(self): + return self._ChatProxy(self) + + class _ChatProxy: + def __init__(self, parent): + self.parent = parent + + @property + def completions(self): + return self.parent._CompletionsProxy(self.parent) + + class _CompletionsProxy: + def __init__(self, parent): + self.parent = parent + + async def create(self, model: str, **kwargs): + provider, actual_model = self.parent._parse_model(model) + client = self.parent.clients[provider] + print("--- Multi Provider Client ---") + print(f"{provider.upper()}: {actual_model}") + return await client.chat.completions.create(model=actual_model, **kwargs) diff --git a/test.py b/test.py new file mode 100644 index 000000000..de63ac8ac --- /dev/null +++ b/test.py @@ -0,0 +1,8 @@ + +model="groq-meta-llama/llama-4-maverick-17b-128e-instruct" + +parts = model.split("-", 1) +provider = parts[0] +actual_model = parts[1] if len(parts) > 1 else model + +print(actual_model) \ No newline at end of file From 5fd14578d0ac8d8dd87c6c5f3d76435ec2e90976 Mon Sep 17 00:00:00 2001 From: Ekin Bozyel Date: Sun, 1 Feb 2026 21:20:54 +0300 Subject: [PATCH 2/7] delete test file --- test.py | 8 -------- 1 file changed, 8 deletions(-) delete mode 100644 test.py diff --git a/test.py b/test.py deleted file mode 100644 index de63ac8ac..000000000 --- a/test.py +++ /dev/null @@ -1,8 +0,0 @@ - -model="groq-meta-llama/llama-4-maverick-17b-128e-instruct" - -parts = model.split("-", 1) -provider = parts[0] -actual_model = parts[1] if len(parts) > 1 else model - -print(actual_model) \ No newline at end of file From 13fcd247f6ffe161486c05331dec3cccad171da4 Mon Sep 17 00:00:00 2001 From: Ekin Bozyel <50263592+john-fante@users.noreply.github.com> Date: Sun, 1 Feb 2026 21:44:31 +0300 Subject: [PATCH 3/7] Update agentlightning/utils/multi_provider_ai_client.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- agentlightning/utils/multi_provider_ai_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agentlightning/utils/multi_provider_ai_client.py b/agentlightning/utils/multi_provider_ai_client.py index 358eb2025..03f6facd6 100644 --- a/agentlightning/utils/multi_provider_ai_client.py +++ b/agentlightning/utils/multi_provider_ai_client.py @@ -11,7 +11,7 @@ ## Usage ```python -from multi_provider_client import MultiProviderClient +from agentlightning.utils.multi_provider_ai_client import MultiProviderClient client = MultiProviderClient() # Use with APO algo = agl.APO( From 13a96df2665c56a6962f6321ae0e464daf3c3c97 Mon Sep 17 00:00:00 2001 From: Ekin Bozyel <50263592+john-fante@users.noreply.github.com> Date: Sun, 1 Feb 2026 21:44:53 +0300 Subject: [PATCH 4/7] Update agentlightning/utils/multi_provider_ai_client.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- agentlightning/utils/multi_provider_ai_client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/agentlightning/utils/multi_provider_ai_client.py b/agentlightning/utils/multi_provider_ai_client.py index 03f6facd6..7e8ffba24 100644 --- a/agentlightning/utils/multi_provider_ai_client.py +++ b/agentlightning/utils/multi_provider_ai_client.py @@ -42,7 +42,8 @@ def __init__(self, custom_providers: dict[str, dict] | None = None): custom_providers: Additional providers. Format: { "provider_name": { - "api_key": "...", # or env var name + "api_key": "...", # literal API key value (optional) + "api_key_env": "ENV_VAR", # name of env var containing the API key (optional) "base_url": "https://..." } } From 9b16ce4a15f268ef62c684466800b52913251d59 Mon Sep 17 00:00:00 2001 From: Ekin Bozyel Date: Mon, 9 Feb 2026 14:56:41 +0300 Subject: [PATCH 5/7] refactor: rename client file and update to LiteLLM --- .../utils/multi_provider_ai_client.py | 136 ------------------ agentlightning/utils/multi_provider_client.py | 69 +++++++++ 2 files changed, 69 insertions(+), 136 deletions(-) delete mode 100644 agentlightning/utils/multi_provider_ai_client.py create mode 100644 agentlightning/utils/multi_provider_client.py diff --git a/agentlightning/utils/multi_provider_ai_client.py b/agentlightning/utils/multi_provider_ai_client.py deleted file mode 100644 index 7e8ffba24..000000000 --- a/agentlightning/utils/multi_provider_ai_client.py +++ /dev/null @@ -1,136 +0,0 @@ -""" -Multi-Provider Client -======================== -Async client that routes to different LLM providers based on model name. - -Usage: - - google-gemini-2.0-flash → Google API - - groq-llama-3.3-70b → Groq API - -Model name must be in "provider-model" format. - -## Usage -```python -from agentlightning.utils.multi_provider_ai_client import MultiProviderClient -client = MultiProviderClient() -# Use with APO -algo = agl.APO( - client, - gradient_model="google-gemini-2.0-flash", - apply_edit_model="groq-llama-3.3-70b-versatile", -) - - -""" - -import os -from openai import AsyncOpenAI - - -class MultiProviderClient: - """Async client that routes to different providers based on model name. - Model format: "provider-model_name" - Examples: - - google-gemini-2.0-flash - - groq-meta-llama/llama-4-maverick-17b-128e-instruct - - etc. - """ - - def __init__(self, custom_providers: dict[str, dict] | None = None): - """ - Args: - custom_providers: Additional providers. Format: - { - "provider_name": { - "api_key": "...", # literal API key value (optional) - "api_key_env": "ENV_VAR", # name of env var containing the API key (optional) - "base_url": "https://..." - } - } - """ - - self.clients = {} - - # Only create clients for providers with API keys - if os.getenv("GOOGLE_API_KEY"): - self.clients["google"] = AsyncOpenAI( - api_key=os.getenv("GOOGLE_API_KEY"), - base_url="https://generativelanguage.googleapis.com/v1beta/openai/") - - if os.getenv("GROQ_API_KEY"): - self.clients["groq"] = AsyncOpenAI( - api_key=os.getenv("GROQ_API_KEY"), - base_url="https://api.groq.com/openai/v1") - - if os.getenv("OPENAI_API_KEY"): - self.clients["openai"] = AsyncOpenAI( - api_key=os.getenv("OPENAI_API_KEY")) - - if os.getenv("AZURE_OPENAI_API_KEY") and os.getenv("AZURE_OPENAI_ENDPOINT"): - self.clients["azure"] = AsyncOpenAI( - api_key=os.getenv("AZURE_OPENAI_API_KEY"), - base_url=os.getenv("AZURE_OPENAI_ENDPOINT")) - - if os.getenv("OPENROUTER_API_KEY"): - self.clients["openrouter"] = AsyncOpenAI( - api_key=os.getenv("OPENROUTER_API_KEY"), - base_url="https://openrouter.ai/api/v1") - - # Add custom providers - if custom_providers: - for name, config in custom_providers.items(): - api_key = config.get("api_key") or os.getenv(config.get("api_key_env", "")) - base_url = config.get("base_url") - self.clients[name] = AsyncOpenAI(api_key=api_key, base_url=base_url) - - - def _parse_model(self, model: str) -> tuple[str, str]: - """Parse model name into provider and actual model name. - - Args: - model: String in "provider-model_name" format - - Returns: - (provider, actual_model_name) tuple - """ - if "-" not in model: - raise ValueError(f"Model format must be 'provider-model_name': {model}") - - idx = model.find("-") - provider = model[:idx] - actual_model = model[idx + 1:] - - if provider not in self.clients: - for name in self.clients: - if model.startswith(name + "-"): - provider = name - actual_model = model[len(name) + 1:] - break - else: - raise ValueError(f"Unknown provider: {provider}. Supported: {list(self.clients.keys())}") - - return provider, actual_model - - - @property - def chat(self): - return self._ChatProxy(self) - - class _ChatProxy: - def __init__(self, parent): - self.parent = parent - - @property - def completions(self): - return self.parent._CompletionsProxy(self.parent) - - class _CompletionsProxy: - def __init__(self, parent): - self.parent = parent - - async def create(self, model: str, **kwargs): - provider, actual_model = self.parent._parse_model(model) - client = self.parent.clients[provider] - print("--- Multi Provider Client ---") - print(f"{provider.upper()}: {actual_model}") - return await client.chat.completions.create(model=actual_model, **kwargs) diff --git a/agentlightning/utils/multi_provider_client.py b/agentlightning/utils/multi_provider_client.py new file mode 100644 index 000000000..57690d22f --- /dev/null +++ b/agentlightning/utils/multi_provider_client.py @@ -0,0 +1,69 @@ +""" +Multi-Provider Client (LiteLLM Version) +======================================== +Async client that routes to different LLM providers using LiteLLM. + +Usage: + - gemini/gemini-2.0-flash → Google API + - groq/llama-3.3-70b → Groq API + - ollama/llama3 → Local Ollama + - openai/ → OpenAI or Custom Base URL + +Model name should follow the standard LiteLLM "provider/model" format. + +## Usage +```python +from agentlightning.utils.multi_provider_ai_client import MultiProviderClient +client = MultiProviderClient() + +# Use with APO +algo = agl.APO( + client, + gradient_model="gemini/gemini-2.0-flash", + apply_edit_model="groq/llama-3.3-70b-versatile", +) + +""" + +from litellm import acompletion + +class MultiProviderClient: + """Async client that routes to different providers using LiteLLM. + Uses standard LiteLLM 'provider/model' format. + """ + + def __init__(self, **kwargs): + """ + Initializes the client. LiteLLM automatically picks up API keys + from environment variables (e.g., GOOGLE_API_KEY, GROQ_API_KEY). + """ + pass + + @property + def chat(self): + return self._ChatProxy(self) + + class _ChatProxy: + def __init__(self, parent): + self.parent = parent + + @property + def completions(self): + return self.parent._CompletionsProxy(self.parent) + + class _CompletionsProxy: + def __init__(self, parent): + self.parent = parent + + async def create(self, model: str, **kwargs): + """ + Passes the request directly to LiteLLM for routing. + + Args: + model: String in "provider/model_name" format. + **kwargs: Additional arguments for the completion call. + """ + print("--- Multi Provider Client (LiteLLM) ---") + print(f"Routing to: {model}") + + return await acompletion(model=model, **kwargs) \ No newline at end of file From 610651fd1e9ae0b0d4a61788801645e978a6e146 Mon Sep 17 00:00:00 2001 From: Ekin Bozyel Date: Wed, 11 Feb 2026 00:22:14 +0300 Subject: [PATCH 6/7] feat(apo): add MultiProviderClient for hybrid-model optimization - Implement MultiProviderClient to support routing different APO stages to different LLM providers. - Add showcasing Gemini and Groq hybrid workflow. - Update with documentation for the new hybrid sample. - Improve static type safety in APO samples to resolve Pylance warnings. --- agentlightning/utils/multi_provider_client.py | 30 +--- examples/apo/README.md | 7 +- examples/apo/apo_multi_provider.py | 138 ++++++++++++++++++ 3 files changed, 146 insertions(+), 29 deletions(-) create mode 100644 examples/apo/apo_multi_provider.py diff --git a/agentlightning/utils/multi_provider_client.py b/agentlightning/utils/multi_provider_client.py index 57690d22f..48d424925 100644 --- a/agentlightning/utils/multi_provider_client.py +++ b/agentlightning/utils/multi_provider_client.py @@ -13,31 +13,19 @@ ## Usage ```python -from agentlightning.utils.multi_provider_ai_client import MultiProviderClient +from agentlightning.utils.multi_provider_client import MultiProviderClient client = MultiProviderClient() -# Use with APO -algo = agl.APO( - client, - gradient_model="gemini/gemini-2.0-flash", - apply_edit_model="groq/llama-3.3-70b-versatile", -) - +``` """ from litellm import acompletion class MultiProviderClient: - """Async client that routes to different providers using LiteLLM. - Uses standard LiteLLM 'provider/model' format. - """ + """Async client that routes to different providers using LiteLLM.""" def __init__(self, **kwargs): - """ - Initializes the client. LiteLLM automatically picks up API keys - from environment variables (e.g., GOOGLE_API_KEY, GROQ_API_KEY). - """ - pass + print("--- Multi Provider Client (LiteLLM) Initialized ---") @property def chat(self): @@ -56,14 +44,4 @@ def __init__(self, parent): self.parent = parent async def create(self, model: str, **kwargs): - """ - Passes the request directly to LiteLLM for routing. - - Args: - model: String in "provider/model_name" format. - **kwargs: Additional arguments for the completion call. - """ - print("--- Multi Provider Client (LiteLLM) ---") - print(f"Routing to: {model}") - return await acompletion(model=model, **kwargs) \ No newline at end of file diff --git a/examples/apo/README.md b/examples/apo/README.md index a3f647853..4dd99ed08 100644 --- a/examples/apo/README.md +++ b/examples/apo/README.md @@ -19,11 +19,12 @@ Follow the [installation guide](../../docs/tutorials/installation.md) to install | `room_selector.py` | Room booking agent implementation using function calling | | `room_selector_apo.py` | Training script using the built-in APO algorithm to optimize prompts | | `room_tasks.jsonl` | Dataset with room booking scenarios and expected selections | -| `apo_custom_algorithm.py` | Tutorial on creating custom algorithms (runnable as algo or runner) | +| `apo_custom_algorithm.py` | Tutorial on creating custom algorithms | | `apo_custom_algorithm_trainer.py` | Shows how to integrate custom algorithms into the Trainer | | `apo_debug.py` | Tutorial demonstrating various agent debugging techniques | -| `legacy_apo_client.py` | Deprecated APO client implementation compatible with Agent-lightning v0.1.x | -| `legacy_apo_server.py` | Deprecated APO server implementation compatible with Agent-lightning v0.1.x | +| `apo_multi_provider.py` | Hybrid optimization sample using multiple LLM backends | +| `legacy_apo_client.py` | Deprecated APO client implementation compatible with v0.1.x | +| `legacy_apo_server.py` | Deprecated APO server implementation compatible with v0.1.x | ## Sample 1: Using Built-in APO Algorithm diff --git a/examples/apo/apo_multi_provider.py b/examples/apo/apo_multi_provider.py new file mode 100644 index 000000000..4acb2d070 --- /dev/null +++ b/examples/apo/apo_multi_provider.py @@ -0,0 +1,138 @@ +""" +This sample code demonstrates how to use the MultiProviderClient with the APO algorithm +to tune mathematical reasoning prompts using a hybrid model setup. +""" + +import logging +import re +import asyncio +import multiprocessing +from typing import Tuple, cast, Dict, Any, List + +from dotenv import load_dotenv +load_dotenv() +import agentlightning as agl +from agentlightning import Trainer, setup_logging, PromptTemplate +from agentlightning.adapter import TraceToMessages +from agentlightning.algorithm.apo import APO +from agentlightning.types import Dataset +from litellm import completion +from agentlightning.utils.multi_provider_client import MultiProviderClient + + +# --- 1. Dataset Logic --- +def load_math_tasks() -> List[Dict[str, str]]: + """Small mock GSM8k-style dataset.""" + return [ + {"question": "If I have 3 apples and buy 2 more, how many do I have?", "expected": "5"}, + {"question": "A train travels 60 miles in 1 hour. How far in 3 hours?", "expected": "180"}, + {"question": "What is the square root of 144?", "expected": "12"}, + {"question": "If a shirt costs $20 and is 10% off, what is the price?", "expected": "18"}, + ] + +def load_train_val_dataset() -> Tuple[Dataset[Dict[str, str]], Dataset[Dict[str, str]]]: + dataset_full = load_math_tasks() + train_split = len(dataset_full) // 2 + # Use list() and cast to satisfy Pylance's SupportsIndex/slice checks + dataset_train = cast(Dataset[Dict[str, str]], list(dataset_full[:train_split])) + dataset_val = cast(Dataset[Dict[str, str]], list(dataset_full[train_split:])) + return dataset_train, dataset_val + +# --- 2. Agent Logic --- +class MathAgent(agl.LitAgent): + def __init__(self): + super().__init__() + + def rollout(self, task: Any, resources: Dict[str, Any], rollout: Any) -> float: + # Pylance fix: Explicitly cast task to Dict + t = cast(Dict[str, str], task) + prompt_template: PromptTemplate = resources.get("prompt_template") # type: ignore + + # Ensure template access is type-safe + template_str = getattr(prompt_template, "template", str(prompt_template)) + prompt = template_str.format(question=t["question"]) + + # Direct LiteLLM call + response = completion( + model="gemini/gemini-2.0-flash", + messages=[{"role": "user", "content": prompt}] + ) + answer = str(response.choices[0].message.content) + + # Reward: Numerical exact match check + pred_nums = re.findall(r"[-+]?\d*\.\d+|\d+", answer.split("Answer:")[-1]) + reward = 1.0 if pred_nums and pred_nums[-1] == t["expected"] else 0.0 + + agl.emit_reward(reward) + return reward + +# --- 3. Logging & Main --- +def setup_apo_logger(file_path: str = "apo_math.log") -> None: + file_handler = logging.FileHandler(file_path) + file_handler.setLevel(logging.INFO) + formatter = logging.Formatter("%(asctime)s [%(levelname)s] (%(name)s) %(message)s") + file_handler.setFormatter(formatter) + logging.getLogger("agentlightning.algorithm.apo").addHandler(file_handler) + +def main() -> None: + setup_logging() + setup_apo_logger() + + multi_client = MultiProviderClient() + + initial_prompt_str = "Solve: {question}" + + algo = APO[Dict[str, str]]( + multi_client, + gradient_model="gemini/gemini-2.0-flash", + apply_edit_model="groq/llama-3.3-70b-versatile", + val_batch_size=2, + gradient_batch_size=2, + beam_width=1, + branch_factor=1, + beam_rounds=1, + ) + + trainer = Trainer( + algorithm=algo, + n_runners=2, + initial_resources={ + "prompt_template": PromptTemplate(template=initial_prompt_str, engine="f-string") + }, + adapter=TraceToMessages(), + ) + + dataset_train, dataset_val = load_train_val_dataset() + agent = MathAgent() + + print("\n" + "="*60) + print("🚀 HYBRID APO OPTIMIZATION STARTING") + print("-" * 60) + + trainer.fit(agent=agent, train_dataset=dataset_train, val_dataset=dataset_val) + + # Print Final Prompt from the store + print("\n" + "="*60) + print("✅ OPTIMIZATION COMPLETE") + print("-" * 60) + print(f"INITIAL PROMPT:\n{initial_prompt_str}") + + + # Accessing the latest optimized prompt from the trainer store + try: + latest_resources = asyncio.run(trainer.store.query_resources()) + if latest_resources: + final_res = latest_resources[-1].resources.get("prompt_template") + final_prompt = getattr(final_res, "template", str(final_res)) + print(f"FINAL OPTIMIZED PROMPT:\n{final_prompt}") + except Exception as e: + print(f"Optimization finished. Check apo_math.log for detailed iteration results. Error: {e}") + + print("="*60 + "\n") + +if __name__ == "__main__": + try: + multiprocessing.set_start_method("fork", force=True) + except RuntimeError: + pass + main() \ No newline at end of file From 0e95e372b1dde7490591c6ea9029fc03cb46b2f2 Mon Sep 17 00:00:00 2001 From: Ekin Bozyel Date: Sun, 1 Mar 2026 15:36:14 +0300 Subject: [PATCH 7/7] refactor(apo): move LiteLLM integration into APO and remove extra client - add use_litellm backend path in APO for chat completion calls\n- remove standalone MultiProviderClient utility to reduce confusion\n- update APO multi-provider example to use APO(use_litellm=True)\n- add APO test coverage for LiteLLM backend path\n- apply minimal typing fixes in apo_multi_provider example --- .gitignore | 1 + agentlightning/algorithm/apo/apo.py | 42 +++++++++++++++-- agentlightning/utils/multi_provider_client.py | 47 ------------------- examples/apo/apo_custom_algorithm.py | 2 +- examples/apo/apo_multi_provider.py | 19 ++++---- tests/algorithm/test_apo.py | 20 ++++++++ 6 files changed, 71 insertions(+), 60 deletions(-) delete mode 100644 agentlightning/utils/multi_provider_client.py diff --git a/.gitignore b/.gitignore index c89af90d7..82ded808b 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ meta-llama/** **/debug/**/*.json requirements-freeze*.txt /playground +.env # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/agentlightning/algorithm/apo/apo.py b/agentlightning/algorithm/apo/apo.py index 894ea8be2..8615da8bd 100644 --- a/agentlightning/algorithm/apo/apo.py +++ b/agentlightning/algorithm/apo/apo.py @@ -35,6 +35,11 @@ import poml from openai import AsyncOpenAI +try: + from litellm import acompletion as litellm_acompletion +except ImportError: + litellm_acompletion = None + from agentlightning.adapter.messages import TraceToMessages from agentlightning.algorithm.base import Algorithm from agentlightning.algorithm.utils import batch_iter_over_dataset, with_llm_proxy, with_store @@ -100,8 +105,9 @@ class APO(Algorithm, Generic[T_task]): def __init__( self, - async_openai_client: AsyncOpenAI, + async_openai_client: Optional[AsyncOpenAI] = None, *, + use_litellm: bool = False, gradient_model: str = "gpt-5-mini", apply_edit_model: str = "gpt-4.1-mini", diversity_temperature: float = 1.0, @@ -122,6 +128,8 @@ def __init__( Args: async_openai_client: AsyncOpenAI client for making LLM API calls. + Optional when ``use_litellm=True``. + use_litellm: If True, uses ``litellm.acompletion`` directly for LLM calls. gradient_model: Model name for computing textual gradients (critiques). apply_edit_model: Model name for applying edits based on critiques. diversity_temperature: Temperature parameter for LLM calls to control diversity. @@ -137,7 +145,11 @@ def __init__( gradient_prompt_files: Prompt templates used to compute textual gradients (critiques). apply_edit_prompt_files: Prompt templates used to apply edits based on critiques. """ + if use_litellm and litellm_acompletion is None: + raise ImportError("litellm is not installed but use_litellm=True was provided.") + self.async_openai_client = async_openai_client + self.use_litellm = use_litellm self.gradient_model = gradient_model self.apply_edit_model = apply_edit_model self.diversity_temperature = diversity_temperature @@ -159,6 +171,30 @@ def __init__( self._poml_trace = _poml_trace + async def _chat_completion_create( + self, + *, + model: str, + messages: Any, + temperature: float, + ) -> Any: + if self.use_litellm: + assert litellm_acompletion is not None + return await litellm_acompletion( + model=model, + messages=messages, + temperature=temperature, + ) + + if self.async_openai_client is None: + raise ValueError("async_openai_client must be provided when use_litellm=False") + + return await self.async_openai_client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + ) + def _create_versioned_prompt( self, prompt_template: PromptTemplate, @@ -307,7 +343,7 @@ async def compute_textual_gradient( f"Gradient computed with {self.gradient_model} prompt: {tg_msg}", prefix=prefix, ) - critique_response = await self.async_openai_client.chat.completions.create( + critique_response = await self._chat_completion_create( model=self.gradient_model, messages=tg_msg["messages"], # type: ignore temperature=self.diversity_temperature, @@ -373,7 +409,7 @@ async def textual_gradient_and_apply_edit( format="openai_chat", ) - ae_response = await self.async_openai_client.chat.completions.create( + ae_response = await self._chat_completion_create( model=self.apply_edit_model, messages=ae_msg["messages"], # type: ignore temperature=self.diversity_temperature, diff --git a/agentlightning/utils/multi_provider_client.py b/agentlightning/utils/multi_provider_client.py deleted file mode 100644 index 48d424925..000000000 --- a/agentlightning/utils/multi_provider_client.py +++ /dev/null @@ -1,47 +0,0 @@ -""" -Multi-Provider Client (LiteLLM Version) -======================================== -Async client that routes to different LLM providers using LiteLLM. - -Usage: - - gemini/gemini-2.0-flash → Google API - - groq/llama-3.3-70b → Groq API - - ollama/llama3 → Local Ollama - - openai/ → OpenAI or Custom Base URL - -Model name should follow the standard LiteLLM "provider/model" format. - -## Usage -```python -from agentlightning.utils.multi_provider_client import MultiProviderClient -client = MultiProviderClient() - -``` -""" - -from litellm import acompletion - -class MultiProviderClient: - """Async client that routes to different providers using LiteLLM.""" - - def __init__(self, **kwargs): - print("--- Multi Provider Client (LiteLLM) Initialized ---") - - @property - def chat(self): - return self._ChatProxy(self) - - class _ChatProxy: - def __init__(self, parent): - self.parent = parent - - @property - def completions(self): - return self.parent._CompletionsProxy(self.parent) - - class _CompletionsProxy: - def __init__(self, parent): - self.parent = parent - - async def create(self, model: str, **kwargs): - return await acompletion(model=model, **kwargs) \ No newline at end of file diff --git a/examples/apo/apo_custom_algorithm.py b/examples/apo/apo_custom_algorithm.py index 45ba529ec..c92ae42eb 100644 --- a/examples/apo/apo_custom_algorithm.py +++ b/examples/apo/apo_custom_algorithm.py @@ -89,7 +89,7 @@ async def apo_algorithm(*, store: agl.LightningStore): await log_llm_span(spans) # 4. The algorithm records the final reward for sorting - final_reward = agl.find_final_reward(spans) + final_reward = agl.find_final_reward(spans) # type: ignore assert final_reward is not None, "Expected a final reward from the client." console.print(f"{algo_marker} Final reward: {final_reward}") prompt_and_rewards.append((prompt, final_reward)) diff --git a/examples/apo/apo_multi_provider.py b/examples/apo/apo_multi_provider.py index 4acb2d070..94d7eebb8 100644 --- a/examples/apo/apo_multi_provider.py +++ b/examples/apo/apo_multi_provider.py @@ -1,5 +1,5 @@ """ -This sample code demonstrates how to use the MultiProviderClient with the APO algorithm +This sample code demonstrates how to use APO's built-in LiteLLM mode to tune mathematical reasoning prompts using a hybrid model setup. """ @@ -7,7 +7,7 @@ import re import asyncio import multiprocessing -from typing import Tuple, cast, Dict, Any, List +from typing import Tuple, cast, Dict, Any, List, Callable from dotenv import load_dotenv load_dotenv() @@ -16,8 +16,9 @@ from agentlightning.adapter import TraceToMessages from agentlightning.algorithm.apo import APO from agentlightning.types import Dataset -from litellm import completion -from agentlightning.utils.multi_provider_client import MultiProviderClient +import litellm + +completion: Callable[..., Any] = cast(Callable[..., Any], getattr(litellm, "completion")) # --- 1. Dataset Logic --- @@ -39,7 +40,7 @@ def load_train_val_dataset() -> Tuple[Dataset[Dict[str, str]], Dataset[Dict[str, return dataset_train, dataset_val # --- 2. Agent Logic --- -class MathAgent(agl.LitAgent): +class MathAgent(agl.LitAgent[Dict[str, str]]): def __init__(self): super().__init__() @@ -57,7 +58,9 @@ def rollout(self, task: Any, resources: Dict[str, Any], rollout: Any) -> float: model="gemini/gemini-2.0-flash", messages=[{"role": "user", "content": prompt}] ) - answer = str(response.choices[0].message.content) + response_obj = response + content = response_obj.choices[0].message.content + answer = content if isinstance(content, str) else "" # Reward: Numerical exact match check pred_nums = re.findall(r"[-+]?\d*\.\d+|\d+", answer.split("Answer:")[-1]) @@ -78,12 +81,10 @@ def main() -> None: setup_logging() setup_apo_logger() - multi_client = MultiProviderClient() - initial_prompt_str = "Solve: {question}" algo = APO[Dict[str, str]]( - multi_client, + use_litellm=True, gradient_model="gemini/gemini-2.0-flash", apply_edit_model="groq/llama-3.3-70b-versatile", val_batch_size=2, diff --git a/tests/algorithm/test_apo.py b/tests/algorithm/test_apo.py index 90d8efd0d..78c3e5611 100644 --- a/tests/algorithm/test_apo.py +++ b/tests/algorithm/test_apo.py @@ -311,6 +311,26 @@ async def test_compute_textual_gradient_uses_all_rollouts_when_insufficient(monk assert result == "critique" +@pytest.mark.asyncio +async def test_compute_textual_gradient_uses_litellm_backend(monkeypatch: pytest.MonkeyPatch) -> None: + litellm_mock = AsyncMock(return_value=make_completion("critique")) + monkeypatch.setattr(apo_module, "litellm_acompletion", litellm_mock) + monkeypatch.setattr(apo_module.random, "choice", lambda seq: seq[0]) # type: ignore + + apo = APO[Any](use_litellm=True, gradient_model="test-gradient-model", gradient_batch_size=1) + versioned_prompt = apo._create_versioned_prompt(PromptTemplate(template="prompt", engine="f-string")) + rollouts: List[RolloutResultForAPO] = [ + RolloutResultForAPO(status="succeeded", final_reward=1.0, spans=[], messages=[]) + ] + + result = await apo.compute_textual_gradient(versioned_prompt, rollouts) + + assert result == "critique" + litellm_mock.assert_awaited_once() + call_kwargs = litellm_mock.await_args.kwargs # type: ignore + assert call_kwargs["model"] == "test-gradient-model" + + @pytest.mark.asyncio async def test_textual_gradient_and_apply_edit_returns_new_prompt(monkeypatch: pytest.MonkeyPatch) -> None: # Use two separate mocks for gradient and edit calls