diff --git a/src/neo4j_graphrag/llm/types.py b/src/neo4j_graphrag/llm/types.py
index 34d68c33..a05835ed 100644
--- a/src/neo4j_graphrag/llm/types.py
+++ b/src/neo4j_graphrag/llm/types.py
@@ -17,8 +17,17 @@ def __getattr__(name: str) -> Any:
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
 
 
+class LLMUsage(BaseModel):
+    prompt_token_count: int
+    candidates_token_count: int
+    # Defaults to 0: providers only report "thinking" tokens for thinking-capable
+    # models, and older SDKs omit the field entirely.
+    thoughts_token_count: int = 0
+    total_token_count: int
+
+
 class LLMResponse(BaseModel):
     content: str
+    usage: Optional[LLMUsage] = None
 
 
 class BaseMessage(BaseModel):
diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py
index 0b492697..792f21b3 100644
--- a/src/neo4j_graphrag/llm/vertexai_llm.py
+++ b/src/neo4j_graphrag/llm/vertexai_llm.py
@@ -30,6 +30,7 @@
     MessageList,
     ToolCall,
     ToolCallResponse,
+    LLMUsage,
 )
 from neo4j_graphrag.message_history import MessageHistory
 from neo4j_graphrag.tool import Tool
@@ -285,6 +286,13 @@ def _parse_tool_response(self, response: GenerationResponse) -> ToolCallResponse
     def _parse_content_response(self, response: GenerationResponse) -> LLMResponse:
         return LLMResponse(
             content=response.text,
+            usage=LLMUsage(
+                prompt_token_count=response.usage_metadata.prompt_token_count,
+                candidates_token_count=response.usage_metadata.candidates_token_count,
+                total_token_count=response.usage_metadata.total_token_count,
+                # getattr guard: attribute is absent on older SDK versions and may
+                # be None for non-thinking models — normalise both to 0.
+                thoughts_token_count=getattr(response.usage_metadata, "thoughts_token_count", 0) or 0,
+            ),
         )
 
     async def ainvoke_with_tools(