From fccd1dd73a597c4ee4ac2f7d946f6565d5838371 Mon Sep 17 00:00:00 2001 From: Enoch Danijel <90690698+Danijel-Enoch@users.noreply.github.com> Date: Mon, 26 Jan 2026 15:47:50 +0100 Subject: [PATCH] feat: Add Groq and OpenRouter LLM providers, update the unified client and configuration, and include new documentation. --- AGENTS.md | 151 ++++++++++++++++ commands/graph.py | 18 ++ config.yaml.example | 22 +++ llm/groq_provider.py | 208 ++++++++++++++++++++++ llm/openrouter_provider.py | 195 +++++++++++++++++++++ llm/unified_client.py | 130 ++++++++------ tests/test_provider_token_tracking.py | 240 ++++++++++++++++---------- 7 files changed, 818 insertions(+), 146 deletions(-) create mode 100644 AGENTS.md create mode 100644 llm/groq_provider.py create mode 100644 llm/openrouter_provider.py diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..232b3a7 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,151 @@ +# Hound - Agent Development Guide + +This file contains build commands, code style guidelines, and development conventions for agents working on this codebase. + +## Build/Lint/Test Commands + +### Core Development Commands +```bash +# Install dependencies +pip install -r requirements.txt +pip install -e .[dev] # Install with dev dependencies + +# Code formatting and linting +black . # Format code (line-length: 100) +ruff check . # Lint code (line-length: 120, see ruff.toml) +ruff check . --fix # Auto-fix linting issues +mypy . # Type checking (strict mode) + +# Testing +pytest # Run all tests +pytest -v # Verbose test output +pytest -x # Stop on first failure +pytest -m "not slow" # Skip slow tests +pytest -m "not integration" # Skip tests requiring external services + +# Single test execution +pytest tests/test_unified_client.py::TestUnifiedClient::test_init_default_provider +pytest tests/ -k "test_function_name" # Run tests matching pattern +pytest tests/test_unified_client.py -k "test_provider_switching" + +# Coverage testing +pytest --cov=. --cov-report=html --cov-report=term + +# Development server +python chatbot/run.py # Start telemetry UI (default: http://127.0.0.1:5280) +``` + +### Hound Application Commands +```bash +# Main entry point (from hound/ directory) +./hound.py --help # Show all available commands + +# Project management +./hound.py project create +./hound.py project ls +./hound.py project info + +# Graph building +./hound.py graph build --auto --files "src/A.sol,src/B.sol" +./hound.py graph refine SystemArchitecture --iterations 2 + +# Agent operations +./hound.py agent audit --mode sweep +./hound.py agent audit --mode intuition --time-limit 300 +./hound.py agent investigate "query" + +# Finalization and reporting +./hound.py finalize +./hound.py report --output /path/to/report.html +``` + +## Code Style Guidelines + +### Imports & Formatting +- Use `ruff` for import sorting; `black` for formatting (line-length: 100) +- Ruff line-length: 120; excludes `chatbot/static` and build artifacts +- Import order: stdlib → third-party → first-party (`llm`, `analysis`, etc.) 
+- Combine as imports: `from typing import Any, TypeVar` +- Type hints use `|` syntax (Python 3.10+): `str | None`, `dict[str, Any]` + +### Type System +- Use `from __future__ import annotations` for forward references +- Strict mypy enabled; use Pydantic BaseModel for data validation +- Generic types: `T = TypeVar('T', bound=BaseModel)` + +### Naming Conventions +- Classes: `PascalCase` (e.g., `UnifiedLLMClient`, `AgentParameters`) +- Functions/variables: `snake_case`; Private: `_leading_underscore` +- Files: `snake_case.py`; Constants: `UPPER_SNAKE_CASE` + +### Error Handling & Async +- Try/except around external API calls with structured error messages +- Use `async/await` for I/O; HTTP via `httpx` +- Logging via `debug_logger` when available + +## Architecture Patterns + +### Provider System +- Base class: `BaseLLMProvider` in `llm/base_provider.py` +- Implementations: `OpenAIProvider`, `GeminiProvider`, `AnthropicProvider`, etc. +- Unified client: `UnifiedLLMClient` handles provider switching +- Configuration-driven provider selection per model profile + +### Data Models +- Pydantic models for structured data (`schemas.py`) +- Validation and serialization built-in +- Nested models for complex data structures +- JSON schema generation for LLM interactions + +### File Organization +``` +hound/ +├── llm/ # LLM provider implementations +├── analysis/ # Core analysis logic and agents +├── commands/ # CLI command implementations +├── utils/ # Shared utilities (config, JSON, CLI) +├── ingest/ # Code parsing and bundling +├── visualization/ # Graph visualization +├── tests/ # Test suite +└── chatbot/ # Telemetry UI +``` + +### Testing Conventions +- Test files: `test_*.py` in `tests/` directory +- Test classes: `TestClassName(unittest.TestCase)` +- Test methods: `test_specific_behavior` +- Fixtures: Use `conftest.py` for shared test setup +- Mocking: Mock external LLM APIs in tests +- Markers: `@pytest.mark.slow`, `@pytest.mark.integration` + +### Configuration Management +- YAML configuration via `config.yaml` +- Environment variable fallbacks +- Config loader with priority ordering (see `utils/config_loader.py`) +- Profile-based model configuration +- API key management via environment variables + +### Session Management +- Session-based audit tracking with per-session planning +- Persistent storage in `~/.hound/projects/`; resume via session IDs +- Token usage tracking per session + +## Development Workflow + +### Adding New LLM Providers +1. Create provider class inheriting from `BaseLLMProvider` +2. Implement required abstract methods +3. Add provider to `UnifiedLLMClient` factory +4. Add tests in `tests/test_provider_token_tracking.py` + +### CLI Command Development +1. Create command file in `commands/` using Typer +2. Follow existing patterns; add help text +3. Test with `./hound.py command --help` + +### Debugging and Troubleshooting +- Enable debug mode: `--debug` saves LLM interactions to `.hound_debug/` +- Use telemetry UI: `python chatbot/run.py` with `--telemetry` +- Check token usage via `TokenTracker` + +This codebase prioritizes modular design, type safety, and comprehensive testing. 
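For concreteness, the "Adding New LLM Providers" workflow above maps to a small amount of code. The sketch below is modeled on the `GroqProvider` and `OpenRouterProvider` added in this patch; the `ExampleProvider` name, the `example` config section, and the base URL are placeholders, and the exact abstract-method set lives in `llm/base_provider.py` (not shown in this patch), so treat the method list as an assumption rather than the definitive interface.

```python
# Minimal sketch of a new OpenAI-compatible provider (steps 1-2 of the
# "Adding New LLM Providers" workflow). Placeholder names are illustrative only;
# the real providers in llm/ also handle retries, verbose logging, and
# structured-output fallbacks.
from __future__ import annotations

import os
from typing import Any, TypeVar

from openai import OpenAI
from pydantic import BaseModel

from .base_provider import BaseLLMProvider

T = TypeVar("T", bound=BaseModel)


class ExampleProvider(BaseLLMProvider):
    """Placeholder provider used only as a template."""

    def __init__(self, config: dict[str, Any], model_name: str, timeout: int = 120, **kwargs):
        self.config = config
        self.model_name = model_name
        self.timeout = timeout
        self._last_token_usage = None
        cfg = config.get("example", {}) if isinstance(config, dict) else {}
        api_key = os.environ.get(cfg.get("api_key_env", "EXAMPLE_API_KEY"))
        if not api_key:
            raise ValueError("API key not found in environment")
        self.client = OpenAI(
            api_key=api_key,
            base_url=cfg.get("base_url", "https://api.example.invalid/v1"),  # placeholder endpoint
        )

    def raw(self, *, system: str, user: str) -> str:
        """Plain text call; records token usage like the other providers."""
        completion = self.client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "system", "content": system}, {"role": "user", "content": user}],
            timeout=self.timeout,
        )
        if completion.usage:
            self._last_token_usage = {
                "input_tokens": completion.usage.prompt_tokens or 0,
                "output_tokens": completion.usage.completion_tokens or 0,
                "total_tokens": completion.usage.total_tokens or 0,
            }
        return completion.choices[0].message.content

    def parse(self, *, system: str, user: str, schema: type[T]) -> T:
        """Structured call: ask for JSON matching the schema, then validate with Pydantic."""
        hint = f"\nRespond with valid JSON matching this schema: {schema.model_json_schema()}"
        return schema.model_validate_json(self.raw(system=system + hint, user=user))

    @property
    def provider_name(self) -> str:
        return "Example"

    @property
    def supports_thinking(self) -> bool:
        return False

    def get_last_token_usage(self) -> dict[str, int] | None:
        return self._last_token_usage
```

Step 3 is then a single additional branch in `UnifiedLLMClient.__init__` (e.g. `elif provider_name == "example": self.provider = ExampleProvider(**common_kwargs, verbose=llm_verbose)`), mirroring the `groq` and `openrouter` branches in `llm/unified_client.py`. Step 4 follows the pattern in `tests/test_provider_token_tracking.py`: patch the provider's `OpenAI` class, mock the `usage` object, and assert on `get_last_token_usage()`.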
\ No newline at end of file diff --git a/commands/graph.py b/commands/graph.py index 6c3d892..ff53b66 100644 --- a/commands/graph.py +++ b/commands/graph.py @@ -483,6 +483,9 @@ def builder_callback(info): console.print(table) console.print(f"\n [bold]Total:[/bold] {total_nodes} nodes, {total_edges} edges") + # Call the new function to write knowledge_graphs.json + _write_knowledge_graphs_metadata(graphs_dir) + # Step 3: Visualization if visualize: console.print("\n[bold]Step 3:[/bold] Visualization") @@ -518,6 +521,21 @@ def builder_callback(info): raise +def _write_knowledge_graphs_metadata(graphs_dir: Path): + """Generates knowledge_graphs.json from graph_*.json files.""" + knowledge_graphs_path = graphs_dir / "knowledge_graphs.json" + graph_files = list(graphs_dir.glob("graph_*.json")) + + graphs_dict = {} + for graph_file in graph_files: + graph_name = graph_file.stem.replace('graph_', '') + # Store absolute path + graphs_dict[graph_name] = str(graph_file.resolve()) + + with open(knowledge_graphs_path, 'w') as f: + json.dump({'graphs': graphs_dict}, f, indent=2) + + def custom( project_id: str, graph_spec_text: str, diff --git a/config.yaml.example b/config.yaml.example index 0fa16dc..798948f 100644 --- a/config.yaml.example +++ b/config.yaml.example @@ -24,6 +24,16 @@ anthropic: xai: api_key_env: XAI_API_KEY +openrouter: + api_key_env: OPENROUTER_API_KEY + # Optional: HTTP-Referer header for OpenRouter rankings + # referer: https://your-site.com + # Optional: X-Title header for OpenRouter dashboard + # app_title: Hound + +groq: + api_key_env: GROQ_API_KEY + # Model configuration with context limits models: # Graph building model - needs large context @@ -75,6 +85,18 @@ models: model: gpt-4o # No max_context specified - will use global default + # Example: OpenRouter model profile (uncomment to use) + # openrouter_agent: + # provider: openrouter + # model: meta-llama/llama-3.1-405b-instruct + # max_context: 128000 + + # Example: Groq model profile for fast inference (uncomment to use) + # groq_fast: + # provider: groq + # model: llama-3.3-70b-versatile + # max_context: 128000 + # Global context settings (used when model doesn't specify max_context) context: max_tokens: 256000 # Default for models without specific max_context diff --git a/llm/groq_provider.py b/llm/groq_provider.py new file mode 100644 index 0000000..d98a9e6 --- /dev/null +++ b/llm/groq_provider.py @@ -0,0 +1,208 @@ +"""Groq provider implementation.""" + +from __future__ import annotations + +import os +import random +import time +from typing import Any, TypeVar + +from openai import OpenAI +from pydantic import BaseModel + +from .base_provider import BaseLLMProvider + +T = TypeVar("T", bound=BaseModel) + + +class GroqProvider(BaseLLMProvider): + """Groq API provider implementation. + + Groq provides high-speed inference for open-source LLMs (Llama, Mixtral, etc.) + through an OpenAI-compatible API. 
+ """ + + def __init__( + self, + config: dict[str, Any], + model_name: str, + timeout: int = 120, + retries: int = 3, + backoff_min: float = 2.0, + backoff_max: float = 8.0, + **kwargs, + ): + """Initialize Groq provider.""" + self.config = config + self.model_name = model_name + self.timeout = timeout + self.retries = retries + self.backoff_min = backoff_min + self.backoff_max = backoff_max + + # Verbose logging toggle (suppress request logs by default) + logging_cfg = config.get("logging", {}) if isinstance(config, dict) else {} + env_verbose = os.environ.get("HOUND_LLM_VERBOSE", "").lower() in {"1", "true", "yes", "on"} + self.verbose = kwargs.get( + "verbose", bool(logging_cfg.get("llm_verbose", False) or env_verbose) + ) + self._last_token_usage = None + + # Get API key from environment + groq_cfg = config.get("groq", {}) if isinstance(config, dict) else {} + api_key_env = groq_cfg.get("api_key_env", "GROQ_API_KEY") + api_key = os.environ.get(api_key_env) + if not api_key: + raise ValueError(f"API key not found in environment variable: {api_key_env}") + + # Get base URL from config or use default Groq endpoint + base_url = groq_cfg.get("base_url", "https://api.groq.com/openai/v1") + + self.client = OpenAI(api_key=api_key, base_url=base_url) + + def parse(self, *, system: str, user: str, schema: type[T]) -> T: + """Make a structured call using Groq's API. + + Uses JSON mode with schema instructions for reliable structured output. + """ + # Log request details + request_chars = len(system) + len(user) + if self.verbose: + print("\n[Groq Request]") + print(f" Model: {self.model_name}") + print(f" Schema: {schema.__name__}") + print(f" Total prompt: {request_chars:,} chars (~{request_chars//4:,} tokens)") + + last_err = None + + for attempt in range(self.retries): + try: + attempt_start = time.time() + if self.verbose: + print(f" Attempt {attempt + 1}/{self.retries}...") + + # Try structured output first (Groq supports it for some models) + try: + messages = [ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ] + + completion = self.client.beta.chat.completions.parse( + model=self.model_name, + messages=messages, + response_format=schema, + timeout=self.timeout, + ) + + # Get the parsed response if available + if completion.choices[0].message.parsed: + parsed_result = completion.choices[0].message.parsed + elif completion.choices[0].message.refusal: + raise RuntimeError( + f"Model refused: {completion.choices[0].message.refusal}" + ) + else: + # Manual parsing fallback + json_str = completion.choices[0].message.content + parsed_result = schema.model_validate_json(json_str) + + except Exception as structured_err: + # Fallback to JSON mode with schema instruction + if self.verbose: + print( + f" Structured output failed, falling back to JSON mode: {structured_err}" + ) + + json_instruction = f"\nRespond with valid JSON matching this schema: {schema.model_json_schema()}" + enhanced_system = system + json_instruction + + enhanced_messages = [ + {"role": "system", "content": enhanced_system}, + {"role": "user", "content": user}, + ] + + completion = self.client.chat.completions.create( + model=self.model_name, + messages=enhanced_messages, + timeout=self.timeout, + response_format={"type": "json_object"}, + ) + + # Parse JSON response manually + json_str = completion.choices[0].message.content + parsed_result = schema.model_validate_json(json_str) + + # Log response details + response_time = time.time() - attempt_start + response_content = 
completion.choices[0].message.content or "" + + # Store token usage + if hasattr(completion, "usage") and completion.usage: + self._last_token_usage = { + "input_tokens": completion.usage.prompt_tokens or 0, + "output_tokens": completion.usage.completion_tokens or 0, + "total_tokens": completion.usage.total_tokens or 0, + } + + if self.verbose: + print(f" Response in {response_time:.2f}s ({len(response_content):,} chars)") + if hasattr(completion, "usage") and completion.usage: + print(f" Tokens: {completion.usage.total_tokens}") + + return parsed_result + + except Exception as e: + last_err = e + if self.verbose: + print(f" Error: {e}") + if attempt < self.retries - 1: + sleep_time = random.uniform(self.backoff_min, self.backoff_max) + if self.verbose: + print(f" Retrying after {sleep_time:.2f}s...") + time.sleep(sleep_time) + + raise RuntimeError(f"Groq call failed after {self.retries} attempts: {last_err}") + + def raw(self, *, system: str, user: str) -> str: + """Make a plain text call.""" + messages = [{"role": "system", "content": system}, {"role": "user", "content": user}] + + last_err = None + for attempt in range(self.retries): + try: + completion = self.client.chat.completions.create( + model=self.model_name, messages=messages, timeout=self.timeout + ) + + # Store token usage + if hasattr(completion, "usage") and completion.usage: + self._last_token_usage = { + "input_tokens": completion.usage.prompt_tokens or 0, + "output_tokens": completion.usage.completion_tokens or 0, + "total_tokens": completion.usage.total_tokens or 0, + } + + return completion.choices[0].message.content + + except Exception as e: + last_err = e + if attempt < self.retries - 1: + sleep_time = random.uniform(self.backoff_min, self.backoff_max) + time.sleep(sleep_time) + + raise RuntimeError(f"Groq raw call failed after {self.retries} attempts: {last_err}") + + @property + def provider_name(self) -> str: + """Return provider name.""" + return "Groq" + + @property + def supports_thinking(self) -> bool: + """Groq models don't have explicit thinking mode.""" + return False + + def get_last_token_usage(self) -> dict[str, int] | None: + """Return token usage from the last call if available.""" + return self._last_token_usage diff --git a/llm/openrouter_provider.py b/llm/openrouter_provider.py new file mode 100644 index 0000000..46da421 --- /dev/null +++ b/llm/openrouter_provider.py @@ -0,0 +1,195 @@ +"""OpenRouter provider implementation.""" + +from __future__ import annotations + +import os +import random +import time +from typing import Any, TypeVar + +from openai import OpenAI +from pydantic import BaseModel + +from .base_provider import BaseLLMProvider + +T = TypeVar("T", bound=BaseModel) + + +class OpenRouterProvider(BaseLLMProvider): + """OpenRouter API provider implementation. + + OpenRouter provides a unified API for accessing multiple LLM providers + (OpenAI, Anthropic, Meta, Google, etc.) through an OpenAI-compatible interface. 
+ """ + + def __init__( + self, + config: dict[str, Any], + model_name: str, + timeout: int = 120, + retries: int = 3, + backoff_min: float = 2.0, + backoff_max: float = 8.0, + **kwargs, + ): + """Initialize OpenRouter provider.""" + self.config = config + self.model_name = model_name + self.timeout = timeout + self.retries = retries + self.backoff_min = backoff_min + self.backoff_max = backoff_max + + # Verbose logging toggle (suppress request logs by default) + logging_cfg = config.get("logging", {}) if isinstance(config, dict) else {} + env_verbose = os.environ.get("HOUND_LLM_VERBOSE", "").lower() in {"1", "true", "yes", "on"} + self.verbose = kwargs.get( + "verbose", bool(logging_cfg.get("llm_verbose", False) or env_verbose) + ) + self._last_token_usage = None + + # Get API key from environment + openrouter_cfg = config.get("openrouter", {}) if isinstance(config, dict) else {} + api_key_env = openrouter_cfg.get("api_key_env", "OPENROUTER_API_KEY") + api_key = os.environ.get(api_key_env) + if not api_key: + raise ValueError(f"API key not found in environment variable: {api_key_env}") + + # Get base URL from config or use default OpenRouter endpoint + base_url = openrouter_cfg.get("base_url", "https://openrouter.ai/api/v1") + + # Optional OpenRouter-specific headers + self.referer = openrouter_cfg.get("referer", "") + self.app_title = openrouter_cfg.get("app_title", "Hound") + + # Build default headers for OpenRouter + default_headers = {} + if self.referer: + default_headers["HTTP-Referer"] = self.referer + if self.app_title: + default_headers["X-Title"] = self.app_title + + self.client = OpenAI( + api_key=api_key, + base_url=base_url, + default_headers=default_headers if default_headers else None, + ) + + def parse(self, *, system: str, user: str, schema: type[T]) -> T: + """Make a structured call using OpenRouter's API. + + Uses JSON mode with schema instructions since OpenRouter proxies to + multiple providers with varying structured output support. 
+ """ + # Log request details + request_chars = len(system) + len(user) + if self.verbose: + print("\n[OpenRouter Request]") + print(f" Model: {self.model_name}") + print(f" Schema: {schema.__name__}") + print(f" Total prompt: {request_chars:,} chars (~{request_chars//4:,} tokens)") + + last_err = None + + for attempt in range(self.retries): + try: + attempt_start = time.time() + if self.verbose: + print(f" Attempt {attempt + 1}/{self.retries}...") + + # Use JSON mode with schema instruction in prompt + # This works reliably across all OpenRouter models + json_instruction = ( + f"\nRespond with valid JSON matching this schema: {schema.model_json_schema()}" + ) + enhanced_system = system + json_instruction + + messages = [ + {"role": "system", "content": enhanced_system}, + {"role": "user", "content": user}, + ] + + completion = self.client.chat.completions.create( + model=self.model_name, + messages=messages, + timeout=self.timeout, + response_format={"type": "json_object"}, + ) + + # Parse JSON response + json_str = completion.choices[0].message.content + parsed_result = schema.model_validate_json(json_str) + + # Log response details + response_time = time.time() - attempt_start + response_content = completion.choices[0].message.content or "" + + # Store token usage + if hasattr(completion, "usage") and completion.usage: + self._last_token_usage = { + "input_tokens": completion.usage.prompt_tokens or 0, + "output_tokens": completion.usage.completion_tokens or 0, + "total_tokens": completion.usage.total_tokens or 0, + } + + if self.verbose: + print(f" Response in {response_time:.2f}s ({len(response_content):,} chars)") + if hasattr(completion, "usage") and completion.usage: + print(f" Tokens: {completion.usage.total_tokens}") + + return parsed_result + + except Exception as e: + last_err = e + if self.verbose: + print(f" Error: {e}") + if attempt < self.retries - 1: + sleep_time = random.uniform(self.backoff_min, self.backoff_max) + if self.verbose: + print(f" Retrying after {sleep_time:.2f}s...") + time.sleep(sleep_time) + + raise RuntimeError(f"OpenRouter call failed after {self.retries} attempts: {last_err}") + + def raw(self, *, system: str, user: str) -> str: + """Make a plain text call.""" + messages = [{"role": "system", "content": system}, {"role": "user", "content": user}] + + last_err = None + for attempt in range(self.retries): + try: + completion = self.client.chat.completions.create( + model=self.model_name, messages=messages, timeout=self.timeout + ) + + # Store token usage + if hasattr(completion, "usage") and completion.usage: + self._last_token_usage = { + "input_tokens": completion.usage.prompt_tokens or 0, + "output_tokens": completion.usage.completion_tokens or 0, + "total_tokens": completion.usage.total_tokens or 0, + } + + return completion.choices[0].message.content + + except Exception as e: + last_err = e + if attempt < self.retries - 1: + sleep_time = random.uniform(self.backoff_min, self.backoff_max) + time.sleep(sleep_time) + + raise RuntimeError(f"OpenRouter raw call failed after {self.retries} attempts: {last_err}") + + @property + def provider_name(self) -> str: + """Return provider name.""" + return "OpenRouter" + + @property + def supports_thinking(self) -> bool: + """OpenRouter models don't have explicit thinking mode.""" + return False + + def get_last_token_usage(self) -> dict[str, int] | None: + """Return token usage from the last call if available.""" + return self._last_token_usage diff --git a/llm/unified_client.py b/llm/unified_client.py index 
aa70c23..c03f4fc 100644 --- a/llm/unified_client.py +++ b/llm/unified_client.py @@ -1,4 +1,5 @@ """Unified LLM client that supports multiple providers.""" + from __future__ import annotations import time @@ -9,31 +10,33 @@ from .anthropic_provider import AnthropicProvider from .deepseek_provider import DeepSeekProvider from .gemini_provider import GeminiProvider +from .groq_provider import GroqProvider from .mock_provider import MockProvider from .openai_provider import OpenAIProvider +from .openrouter_provider import OpenRouterProvider from .token_tracker import get_token_tracker from .xai_provider import XAIProvider -T = TypeVar('T', bound=BaseModel) +T = TypeVar("T", bound=BaseModel) class UnifiedLLMClient: """ Unified LLM client that can work with multiple providers. - + Maintains backward compatibility with existing code while supporting new providers like Gemini. """ - + def __init__(self, cfg: dict[str, Any], profile: str = "graph", debug_logger=None): """ Initialize unified LLM client with config and profile. - + The config now supports a 'provider' field for each model profile. If not specified, defaults to 'openai' for backward compatibility. - + Available providers: openai, gemini, anthropic, xai - + Args: cfg: Configuration dictionary profile: Model profile to use (e.g., "graph", "agent") @@ -42,7 +45,7 @@ def __init__(self, cfg: dict[str, Any], profile: str = "graph", debug_logger=Non self.cfg = cfg self.profile = profile self.debug_logger = debug_logger - + # Get model configuration for this profile with backward-compatible mapping models_cfg = cfg.get("models", {}) if isinstance(cfg, dict) else {} profile_key = profile @@ -61,17 +64,19 @@ def __init__(self, cfg: dict[str, Any], profile: str = "graph", debug_logger=Non profile_key = alt break if profile_key not in models_cfg: - raise ValueError(f"Model profile '{profile}' not found in config and no fallback available") + raise ValueError( + f"Model profile '{profile}' not found in config and no fallback available" + ) model_config = models_cfg[profile_key] self.model = model_config["model"] - + # Determine provider (default to openai for backward compatibility) provider_name = model_config.get("provider", "openai").lower() - + # Get provider-specific configuration timeout_cfg = cfg.get("timeouts", {}) retry_cfg = cfg.get("retries", {}) - + common_kwargs = { "config": cfg, "model_name": self.model, @@ -80,13 +85,19 @@ def __init__(self, cfg: dict[str, Any], profile: str = "graph", debug_logger=Non "backoff_min": retry_cfg.get("backoff_min_seconds", 2), "backoff_max": retry_cfg.get("backoff_max_seconds", 8), } - + # Determine verbose logging logging_cfg = cfg.get("logging", {}) if isinstance(cfg, dict) else {} env_verbose = False try: import os as _os - env_verbose = _os.environ.get("HOUND_LLM_VERBOSE", "").lower() in {"1","true","yes","on"} + + env_verbose = _os.environ.get("HOUND_LLM_VERBOSE", "").lower() in { + "1", + "true", + "yes", + "on", + } except Exception: pass llm_verbose = bool(logging_cfg.get("llm_verbose", False) or env_verbose) @@ -97,78 +108,79 @@ def __init__(self, cfg: dict[str, Any], profile: str = "graph", debug_logger=Non **common_kwargs, reasoning_effort=model_config.get("reasoning_effort"), text_verbosity=model_config.get("text_verbosity"), - verbose=llm_verbose + verbose=llm_verbose, ) elif provider_name == "gemini": self.provider = GeminiProvider( **common_kwargs, thinking_enabled=model_config.get("thinking_enabled", False), - thinking_budget=model_config.get("thinking_budget", -1) + 
thinking_budget=model_config.get("thinking_budget", -1), ) elif provider_name == "anthropic": self.provider = AnthropicProvider( **common_kwargs, api_key_env=cfg.get("anthropic", {}).get("api_key_env", "ANTHROPIC_API_KEY"), verbose=llm_verbose, - thinking_enabled=model_config.get("thinking_enabled", False) + thinking_enabled=model_config.get("thinking_enabled", False), ) elif provider_name == "xai": - self.provider = XAIProvider( - **common_kwargs, - verbose=llm_verbose - ) + self.provider = XAIProvider(**common_kwargs, verbose=llm_verbose) elif provider_name == "deepseek": - self.provider = DeepSeekProvider( - **common_kwargs, - verbose=llm_verbose - ) + self.provider = DeepSeekProvider(**common_kwargs, verbose=llm_verbose) + elif provider_name == "openrouter": + self.provider = OpenRouterProvider(**common_kwargs, verbose=llm_verbose) + elif provider_name == "groq": + self.provider = GroqProvider(**common_kwargs, verbose=llm_verbose) elif provider_name == "mock": # Mock provider for testing mock_instance = model_config.get("mock_instance") - self.provider = MockProvider( - **common_kwargs, - mock_instance=mock_instance - ) + self.provider = MockProvider(**common_kwargs, mock_instance=mock_instance) else: raise ValueError(f"Unknown provider: {provider_name}") # Only surface provider init when verbose explicitly enabled if llm_verbose: - print(f"[*] Initialized {self.provider.provider_name} provider with model: {self.model}") - if self.provider.supports_thinking and hasattr(self.provider, 'thinking_enabled'): + print( + f"[*] Initialized {self.provider.provider_name} provider with model: {self.model}" + ) + if self.provider.supports_thinking and hasattr(self.provider, "thinking_enabled"): if self.provider.thinking_enabled: print(" Thinking mode: Enabled") - - def parse(self, *, system: str, user: str, schema: type[T], reasoning_effort: str | None = None) -> T: + + def parse( + self, *, system: str, user: str, schema: type[T], reasoning_effort: str | None = None + ) -> T: """ Structured call: returns an instance of `schema` (Pydantic BaseModel). - + Delegates to the underlying provider. 
""" start_time = time.time() error = None response = None - + try: try: - response = self.provider.parse(system=system, user=user, schema=schema, reasoning_effort=reasoning_effort) + response = self.provider.parse( + system=system, user=user, schema=schema, reasoning_effort=reasoning_effort + ) except TypeError: # Provider may not support per-call overrides response = self.provider.parse(system=system, user=user, schema=schema) - + # Track token usage if provider supports it - if hasattr(self.provider, 'get_last_token_usage'): + if hasattr(self.provider, "get_last_token_usage"): token_usage = self.provider.get_last_token_usage() if token_usage: tracker = get_token_tracker() tracker.track_usage( provider=self.provider.provider_name, model=self.model, - input_tokens=token_usage.get('input_tokens', 0), - output_tokens=token_usage.get('output_tokens', 0), - profile=self.profile + input_tokens=token_usage.get("input_tokens", 0), + output_tokens=token_usage.get("output_tokens", 0), + profile=self.profile, ) - + return response except Exception as e: error = str(e) @@ -180,42 +192,46 @@ def parse(self, *, system: str, user: str, schema: type[T], reasoning_effort: st self.debug_logger.log_interaction( system_prompt=system, user_prompt=user, - response=response.dict() if response and hasattr(response, 'dict') else response, + response=( + response.dict() if response and hasattr(response, "dict") else response + ), schema=schema, duration=duration, error=error, - profile=self.profile + profile=self.profile, ) - + def raw(self, *, system: str, user: str, reasoning_effort: str | None = None) -> str: """ Plain text call (no schema). - + Delegates to the underlying provider. """ start_time = time.time() error = None response = None - + try: try: - response = self.provider.raw(system=system, user=user, reasoning_effort=reasoning_effort) + response = self.provider.raw( + system=system, user=user, reasoning_effort=reasoning_effort + ) except TypeError: response = self.provider.raw(system=system, user=user) - + # Track token usage if provider supports it - if hasattr(self.provider, 'get_last_token_usage'): + if hasattr(self.provider, "get_last_token_usage"): token_usage = self.provider.get_last_token_usage() if token_usage: tracker = get_token_tracker() tracker.track_usage( provider=self.provider.provider_name, model=self.model, - input_tokens=token_usage.get('input_tokens', 0), - output_tokens=token_usage.get('output_tokens', 0), - profile=self.profile + input_tokens=token_usage.get("input_tokens", 0), + output_tokens=token_usage.get("output_tokens", 0), + profile=self.profile, ) - + return response except Exception as e: error = str(e) @@ -231,21 +247,21 @@ def raw(self, *, system: str, user: str, reasoning_effort: str | None = None) -> schema=None, duration=duration, error=error, - profile=self.profile + profile=self.profile, ) - + def generate(self, *, system: str, user: str) -> str: """ Alias for raw() - plain text generation. Added for compatibility with agent code. 
""" return self.raw(system=system, user=user) - + @property def provider_name(self) -> str: """Get the name of the current provider.""" return self.provider.provider_name - + @property def supports_thinking(self) -> bool: """Check if the current provider/model supports thinking mode.""" diff --git a/tests/test_provider_token_tracking.py b/tests/test_provider_token_tracking.py index e14f80d..dcd89c6 100644 --- a/tests/test_provider_token_tracking.py +++ b/tests/test_provider_token_tracking.py @@ -1,4 +1,5 @@ """Unit tests for LLM provider token tracking.""" + import sys import unittest from pathlib import Path @@ -8,55 +9,51 @@ from llm.anthropic_provider import AnthropicProvider from llm.gemini_provider import GeminiProvider +from llm.groq_provider import GroqProvider from llm.openai_provider import OpenAIProvider +from llm.openrouter_provider import OpenRouterProvider from llm.xai_provider import XAIProvider class TestOpenAIProviderTokenTracking(unittest.TestCase): """Test OpenAI provider token tracking.""" - - @patch.dict('os.environ', {'OPENAI_API_KEY': 'test_key'}) - @patch('llm.openai_provider.OpenAI') + + @patch.dict("os.environ", {"OPENAI_API_KEY": "test_key"}) + @patch("llm.openai_provider.OpenAI") def setUp(self, mock_openai): """Set up test fixtures.""" self.mock_client = Mock() mock_openai.return_value = self.mock_client - + self.provider = OpenAIProvider( - config={'openai': {'api_key_env': 'OPENAI_API_KEY'}}, - model_name='gpt-4', - timeout=30 + config={"openai": {"api_key_env": "OPENAI_API_KEY"}}, model_name="gpt-4", timeout=30 ) - + def test_initial_token_usage_is_none(self): """Test that initial token usage is None.""" self.assertIsNone(self.provider.get_last_token_usage()) - + def test_raw_call_tracks_tokens(self): """Test that raw calls track token usage.""" # Mock response with usage data mock_response = Mock() mock_response.choices = [Mock(message=Mock(content="Test response"))] - mock_response.usage = Mock( - prompt_tokens=100, - completion_tokens=50, - total_tokens=150 - ) + mock_response.usage = Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150) self.mock_client.chat.completions.create.return_value = mock_response - + # Make a raw call result = self.provider.raw(system="Test system", user="Test user") - + # Check result self.assertEqual(result, "Test response") - + # Check token tracking usage = self.provider.get_last_token_usage() self.assertIsNotNone(usage) - self.assertEqual(usage['input_tokens'], 100) - self.assertEqual(usage['output_tokens'], 50) - self.assertEqual(usage['total_tokens'], 150) - + self.assertEqual(usage["input_tokens"], 100) + self.assertEqual(usage["output_tokens"], 50) + self.assertEqual(usage["total_tokens"], 150) + def test_parse_call_tracks_tokens(self): """Test that parse calls track token usage.""" # Mock response with usage data @@ -66,167 +63,232 @@ def test_parse_call_tracks_tokens(self): mock_message.parsed = Mock(test="data") mock_message.refusal = None mock_response.choices = [Mock(message=mock_message)] - mock_response.usage = Mock( - prompt_tokens=200, - completion_tokens=100, - total_tokens=300 - ) + mock_response.usage = Mock(prompt_tokens=200, completion_tokens=100, total_tokens=300) self.mock_client.beta.chat.completions.parse.return_value = mock_response - + # Create a mock schema class TestSchema: def __init__(self, **kwargs): for k, v in kwargs.items(): setattr(self, k, v) - + # Make a parse call self.provider.parse(system="Test", user="Test", schema=TestSchema) - + # Check token tracking usage = 
self.provider.get_last_token_usage() self.assertIsNotNone(usage) - self.assertEqual(usage['input_tokens'], 200) - self.assertEqual(usage['output_tokens'], 100) - self.assertEqual(usage['total_tokens'], 300) + self.assertEqual(usage["input_tokens"], 200) + self.assertEqual(usage["output_tokens"], 100) + self.assertEqual(usage["total_tokens"], 300) class TestGeminiProviderTokenTracking(unittest.TestCase): """Test Gemini provider token tracking.""" - - @patch.dict('os.environ', {'GOOGLE_API_KEY': 'test_key'}) - @patch('llm.gemini_provider._USE_NEW_GENAI', False) - @patch('llm.gemini_provider._genai_new', None) - @patch('llm.gemini_provider._genai_legacy', None) - @patch('llm.gemini_provider.genai') + + @patch.dict("os.environ", {"GOOGLE_API_KEY": "test_key"}) + @patch("llm.gemini_provider._USE_NEW_GENAI", False) + @patch("llm.gemini_provider._genai_new", None) + @patch("llm.gemini_provider._genai_legacy", None) + @patch("llm.gemini_provider.genai") def setUp(self, mock_genai, *_): """Set up test fixtures.""" mock_genai.configure = Mock() self.mock_model = Mock() mock_genai.GenerativeModel.return_value = self.mock_model - + self.provider = GeminiProvider( - config={'gemini': {'api_key_env': 'GOOGLE_API_KEY'}}, - model_name='gemini-2.0-flash' + config={"gemini": {"api_key_env": "GOOGLE_API_KEY"}}, model_name="gemini-2.0-flash" ) - + def test_initial_token_usage_is_none(self): """Test that initial token usage is None.""" self.assertIsNone(self.provider.get_last_token_usage()) - + def test_raw_call_tracks_tokens(self): """Test that raw calls track token usage.""" # Mock response with usage metadata mock_response = Mock() mock_response.text = "Test response" mock_response.usage_metadata = Mock( - prompt_token_count=150, - candidates_token_count=75, - total_token_count=225 + prompt_token_count=150, candidates_token_count=75, total_token_count=225 ) self.mock_model.generate_content.return_value = mock_response - + # Make a raw call result = self.provider.raw(system="Test system", user="Test user") - + # Check result self.assertEqual(result, "Test response") - + # Check token tracking usage = self.provider.get_last_token_usage() self.assertIsNotNone(usage) - self.assertEqual(usage['input_tokens'], 150) - self.assertEqual(usage['output_tokens'], 75) - self.assertEqual(usage['total_tokens'], 225) + self.assertEqual(usage["input_tokens"], 150) + self.assertEqual(usage["output_tokens"], 75) + self.assertEqual(usage["total_tokens"], 225) class TestAnthropicProviderTokenTracking(unittest.TestCase): """Test Anthropic provider token tracking.""" - - @patch.dict('os.environ', {'ANTHROPIC_API_KEY': 'test_key'}) - @patch('anthropic.Anthropic') + + @patch.dict("os.environ", {"ANTHROPIC_API_KEY": "test_key"}) + @patch("anthropic.Anthropic") def setUp(self, mock_anthropic_class): """Set up test fixtures.""" self.mock_client = Mock() mock_anthropic_class.return_value = self.mock_client - + self.provider = AnthropicProvider( - model_name='claude-3-5-sonnet', - api_key_env='ANTHROPIC_API_KEY' + model_name="claude-3-5-sonnet", api_key_env="ANTHROPIC_API_KEY" ) - + def test_initial_token_usage_is_none(self): """Test that initial token usage is None.""" self.assertIsNone(self.provider.get_last_token_usage()) - + def test_raw_call_tracks_tokens(self): """Test that raw calls track token usage.""" # Mock response with usage data mock_response = Mock() mock_response.content = [Mock(text="Test response")] - mock_response.usage = Mock( - input_tokens=120, - output_tokens=60 - ) + mock_response.usage = Mock(input_tokens=120, 
output_tokens=60) self.mock_client.messages.create.return_value = mock_response - + # Make a raw call result = self.provider.raw(system="Test system", user="Test user") - + # Check result self.assertEqual(result, "Test response") - + # Check token tracking usage = self.provider.get_last_token_usage() self.assertIsNotNone(usage) - self.assertEqual(usage['input_tokens'], 120) - self.assertEqual(usage['output_tokens'], 60) - self.assertEqual(usage['total_tokens'], 180) + self.assertEqual(usage["input_tokens"], 120) + self.assertEqual(usage["output_tokens"], 60) + self.assertEqual(usage["total_tokens"], 180) class TestXAIProviderTokenTracking(unittest.TestCase): """Test XAI provider token tracking.""" - - @patch.dict('os.environ', {'XAI_API_KEY': 'test_key'}) - @patch('llm.xai_provider.OpenAI') + + @patch.dict("os.environ", {"XAI_API_KEY": "test_key"}) + @patch("llm.xai_provider.OpenAI") def setUp(self, mock_openai): """Set up test fixtures.""" self.mock_client = Mock() mock_openai.return_value = self.mock_client - + self.provider = XAIProvider( - config={'xai': {'api_key_env': 'XAI_API_KEY'}}, - model_name='grok-2' + config={"xai": {"api_key_env": "XAI_API_KEY"}}, model_name="grok-2" ) - + def test_initial_token_usage_is_none(self): """Test that initial token usage is None.""" self.assertIsNone(self.provider.get_last_token_usage()) - + def test_raw_call_tracks_tokens(self): """Test that raw calls track token usage.""" # Mock response with usage data (XAI uses OpenAI-compatible API) mock_response = Mock() mock_response.choices = [Mock(message=Mock(content="Test response"))] - mock_response.usage = Mock( - prompt_tokens=90, - completion_tokens=45, - total_tokens=135 + mock_response.usage = Mock(prompt_tokens=90, completion_tokens=45, total_tokens=135) + self.mock_client.chat.completions.create.return_value = mock_response + + # Make a raw call + result = self.provider.raw(system="Test system", user="Test user") + + # Check result + self.assertEqual(result, "Test response") + + # Check token tracking + usage = self.provider.get_last_token_usage() + self.assertIsNotNone(usage) + self.assertEqual(usage["input_tokens"], 90) + self.assertEqual(usage["output_tokens"], 45) + self.assertEqual(usage["total_tokens"], 135) + + +class TestOpenRouterProviderTokenTracking(unittest.TestCase): + """Test OpenRouter provider token tracking.""" + + @patch.dict("os.environ", {"OPENROUTER_API_KEY": "test_key"}) + @patch("llm.openrouter_provider.OpenAI") + def setUp(self, mock_openai): + """Set up test fixtures.""" + self.mock_client = Mock() + mock_openai.return_value = self.mock_client + + self.provider = OpenRouterProvider( + config={"openrouter": {"api_key_env": "OPENROUTER_API_KEY"}}, + model_name="meta-llama/llama-3.1-70b-instruct", + ) + + def test_initial_token_usage_is_none(self): + """Test that initial token usage is None.""" + self.assertIsNone(self.provider.get_last_token_usage()) + + def test_raw_call_tracks_tokens(self): + """Test that raw calls track token usage.""" + # Mock response with usage data (OpenRouter uses OpenAI-compatible API) + mock_response = Mock() + mock_response.choices = [Mock(message=Mock(content="Test response"))] + mock_response.usage = Mock(prompt_tokens=80, completion_tokens=40, total_tokens=120) + self.mock_client.chat.completions.create.return_value = mock_response + + # Make a raw call + result = self.provider.raw(system="Test system", user="Test user") + + # Check result + self.assertEqual(result, "Test response") + + # Check token tracking + usage = 
self.provider.get_last_token_usage() + self.assertIsNotNone(usage) + self.assertEqual(usage["input_tokens"], 80) + self.assertEqual(usage["output_tokens"], 40) + self.assertEqual(usage["total_tokens"], 120) + + +class TestGroqProviderTokenTracking(unittest.TestCase): + """Test Groq provider token tracking.""" + + @patch.dict("os.environ", {"GROQ_API_KEY": "test_key"}) + @patch("llm.groq_provider.OpenAI") + def setUp(self, mock_openai): + """Set up test fixtures.""" + self.mock_client = Mock() + mock_openai.return_value = self.mock_client + + self.provider = GroqProvider( + config={"groq": {"api_key_env": "GROQ_API_KEY"}}, model_name="llama-3.3-70b-versatile" ) + + def test_initial_token_usage_is_none(self): + """Test that initial token usage is None.""" + self.assertIsNone(self.provider.get_last_token_usage()) + + def test_raw_call_tracks_tokens(self): + """Test that raw calls track token usage.""" + # Mock response with usage data (Groq uses OpenAI-compatible API) + mock_response = Mock() + mock_response.choices = [Mock(message=Mock(content="Test response"))] + mock_response.usage = Mock(prompt_tokens=70, completion_tokens=35, total_tokens=105) self.mock_client.chat.completions.create.return_value = mock_response - + # Make a raw call result = self.provider.raw(system="Test system", user="Test user") - + # Check result self.assertEqual(result, "Test response") - + # Check token tracking usage = self.provider.get_last_token_usage() self.assertIsNotNone(usage) - self.assertEqual(usage['input_tokens'], 90) - self.assertEqual(usage['output_tokens'], 45) - self.assertEqual(usage['total_tokens'], 135) + self.assertEqual(usage["input_tokens"], 70) + self.assertEqual(usage["output_tokens"], 35) + self.assertEqual(usage["total_tokens"], 105) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main()
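Usage note for the configuration side of this patch: a model profile only needs `provider` and `model` keys to route calls to one of the new providers. The snippet below is a minimal sketch, assuming `GROQ_API_KEY` is exported; the `groq_fast` profile and model name are taken from the commented example added to `config.yaml.example`.

```python
# Minimal sketch: drive the new Groq provider through UnifiedLLMClient.
# Assumes GROQ_API_KEY is set; profile/model names follow config.yaml.example.
from llm.unified_client import UnifiedLLMClient

cfg = {
    "groq": {"api_key_env": "GROQ_API_KEY"},
    "models": {
        "groq_fast": {
            "provider": "groq",
            "model": "llama-3.3-70b-versatile",
            "max_context": 128000,
        },
    },
}

client = UnifiedLLMClient(cfg, profile="groq_fast")
print(client.provider_name)  # -> "Groq"
print(client.raw(system="You are terse.", user="Say hello."))
# Token usage from each call is recorded automatically via the shared token tracker.
```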