# Patch summary (free text ahead of the first file header; not part of any hunk).
#
# Files touched:
#   examples/agents/simple_chat.py
#   examples/agents/utils.py
#
# examples/agents/simple_chat.py
#   * Imports get_any_available_chat_model in place of get_any_available_model
#     and calls it when no model_id was supplied.
#   * Removes the client.shields.list() lookup and its two status messages.
#   * Agent construction: drops the input_shields, output_shields and
#     enable_session_persistence=False arguments, and replaces
#     tools=["builtin::websearch"] with tools=[{"type": "web_search"}].
#   * Output loop: items yielded by AgentEventLogger().log(response) are now
#     passed to print(..., end="", flush=True) instead of calling .print() on
#     each item.
#
# examples/agents/utils.py
#   * Adds _get_model_type: looks for a "model_type" or "type" string in the
#     model's custom_metadata / metadata dict attributes; None if absent.
#   * Adds _is_llm_model: True when that type is None or equals "llm".
#   * Adds _get_model_id: first string found among the attributes identifier,
#     model_id, id, name; None if there is none.
#   * check_model_is_available and get_any_available_model now filter through
#     these helpers instead of reading model.identifier / model.model_type
#     directly; ids containing "guard" are still excluded.
#   * Adds can_model_chat: issues client.chat.completions.create with a single
#     "ping" user message and max_tokens=1; returns False if that call raises
#     any Exception, True otherwise.
#   * Adds get_any_available_chat_model: builds the same candidate list as
#     get_any_available_model and returns the first id for which
#     can_model_chat is True; prints a red message and returns None when there
#     are no candidates or none passes the probe.
#
# NOTE(review): can_model_chat catches Exception broadly, so a connection or
#   authentication failure is reported the same way as "model cannot chat".
#   Consider narrowing the exception or logging it.
# NOTE(review): get_any_available_chat_model sends one real completion request
#   per candidate until one succeeds.
# NOTE(review): the candidate-list comprehension is repeated verbatim in
#   get_any_available_model and get_any_available_chat_model; the latter could
#   reuse a shared helper.
# NOTE(review): the last utils.py hunk adds only one blank line between
#   can_model_chat and get_any_available_chat_model, whereas the other new
#   top-level functions are separated by two (PEP 8).
# NOTE(review): in this copy the hunk lines appear run together on a few long
#   physical lines; the original newlines and indentation presumably need to
#   be restored before the patch can be applied - confirm against the source
#   commit.
#
diff --git a/examples/agents/simple_chat.py b/examples/agents/simple_chat.py index a76a4357..70570695 100644 --- a/examples/agents/simple_chat.py +++ b/examples/agents/simple_chat.py @@ -9,7 +9,7 @@ from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger from termcolor import colored -from .utils import check_model_is_available, get_any_available_model +from .utils import check_model_is_available, get_any_available_chat_model def main(host: str, port: int, model_id: str | None = None): @@ -27,14 +27,8 @@ def main(host: str, port: int, model_id: str | None = None): provider_data={"tavily_search_api_key": os.getenv("TAVILY_SEARCH_API_KEY")}, ) - available_shields = [shield.identifier for shield in client.shields.list()] - if not available_shields: - print(colored("No available shields. Disabling safety.", "yellow")) - else: - print(f"Available shields found: {available_shields}") - if model_id is None: - model_id = get_any_available_model(client) + model_id = get_any_available_chat_model(client) if model_id is None: return else: @@ -47,10 +41,8 @@ def main(host: str, port: int, model_id: str | None = None): client, model=model_id, instructions="", - tools=["builtin::websearch"], - input_shields=available_shields, - output_shields=available_shields, - enable_session_persistence=False, + # OpenAI Responses tool schema requires a type discriminator. 
+ tools=[{"type": "web_search"}], ) user_prompts = [ "Hello", @@ -65,8 +57,8 @@ def main(host: str, port: int, model_id: str | None = None): session_id=session_id, ) - for log in AgentEventLogger().log(response): - log.print() + for printable in AgentEventLogger().log(response): + print(printable, end="", flush=True) if __name__ == "__main__": diff --git a/examples/agents/utils.py b/examples/agents/utils.py index 298be9a0..8878d81e 100644 --- a/examples/agents/utils.py +++ b/examples/agents/utils.py @@ -2,11 +2,36 @@ from termcolor import colored +def _get_model_type(model) -> str | None: + for metadata_attr in ("custom_metadata", "metadata"): + metadata = getattr(model, metadata_attr, None) + if isinstance(metadata, dict): + value = metadata.get("model_type") or metadata.get("type") + if isinstance(value, str): + return value + return None + + +def _is_llm_model(model) -> bool: + model_type = _get_model_type(model) + # If the client schema doesn't expose type fields, assume LLM. + return model_type is None or model_type == "llm" + + +def _get_model_id(model) -> str | None: + for attr in ("identifier", "model_id", "id", "name"): + value = getattr(model, attr, None) + if isinstance(value, str): + return value + return None + + def check_model_is_available(client: LlamaStackClient, model: str): available_models = [ - model.identifier + model_id for model in client.models.list() - if model.model_type == "llm" and "guard" not in model.identifier + for model_id in [_get_model_id(model)] + if model_id and _is_llm_model(model) and "guard" not in model_id ] if model not in available_models: @@ -23,12 +48,44 @@ def check_model_is_available(client: LlamaStackClient, model: str): def get_any_available_model(client: LlamaStackClient): available_models = [ - model.identifier + model_id for model in client.models.list() - if model.model_type == "llm" and "guard" not in model.identifier + for model_id in [_get_model_id(model)] + if model_id and _is_llm_model(model) and "guard" 
not in model_id ] if not available_models: print(colored("No available models.", "red")) return None return available_models[0] + + +def can_model_chat(client: LlamaStackClient, model_id: str) -> bool: + # Lightweight probe to ensure the model supports chat completions. + try: + client.chat.completions.create( + model=model_id, + messages=[{"role": "user", "content": "ping"}], + max_tokens=1, + ) + except Exception: + return False + return True + +def get_any_available_chat_model(client: LlamaStackClient): + available_models = [ + model_id + for model in client.models.list() + for model_id in [_get_model_id(model)] + if model_id and _is_llm_model(model) and "guard" not in model_id + ] + if not available_models: + print(colored("No available models.", "red")) + return None + + for model_id in available_models: + if can_model_chat(client, model_id): + return model_id + + print(colored("No available chat-capable models.", "red")) + return None