diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index 0b492697..38aa72d4 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -34,17 +34,17 @@ from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.tool import Tool from neo4j_graphrag.types import LLMMessage +from google import genai +from google.genai import types try: from vertexai.generative_models import ( Content, FunctionCall, - FunctionDeclaration, GenerationResponse, GenerativeModel, Part, ResponseValidationError, - Tool as VertexAITool, ToolConfig, ) except ImportError: @@ -189,8 +189,8 @@ async def ainvoke( except ResponseValidationError as e: raise LLMGenerationError("Error calling VertexAILLM") from e - def _to_vertexai_function_declaration(self, tool: Tool) -> FunctionDeclaration: - return FunctionDeclaration( + def _to_vertexai_function_declaration(self, tool: Tool) -> types.FunctionDeclaration: + return types.FunctionDeclaration( name=tool.get_name(), description=tool.get_description(), parameters=tool.get_parameters(exclude=["additional_properties"]), @@ -198,11 +198,11 @@ def _to_vertexai_function_declaration(self, tool: Tool) -> FunctionDeclaration: def _get_llm_tools( self, tools: Optional[Sequence[Tool]] - ) -> Optional[list[VertexAITool]]: + ) -> Optional[list[types.Tool]]: if not tools: return None return [ - VertexAITool( + types.Tool( function_declarations=[ self._to_vertexai_function_declaration(tool) for tool in tools ] @@ -254,7 +254,20 @@ async def _acall_llm( ) -> GenerationResponse: model = self._get_model(system_instruction=system_instruction) options = self._get_call_params(input, message_history, tools) - response = await model.generate_content_async(**options) + + client = genai.Client() + response = await client.aio.models.generate_content( + model=self.model_name, + contents=types.Content( + role="user", + parts=[types.Part.from_text(text=input)], + ), + config=types.GenerateContentConfig( + tools=options["tools"], + system_instruction=system_instruction, + temperature=0.0, + ), + ) return response # type: ignore[no-any-return] def _call_llm( @@ -276,7 +289,7 @@ def _to_tool_call(self, function_call: FunctionCall) -> ToolCall: ) def _parse_tool_response(self, response: GenerationResponse) -> ToolCallResponse: - function_calls = response.candidates[0].function_calls + function_calls = response.function_calls return ToolCallResponse( tool_calls=[self._to_tool_call(f) for f in function_calls], content=None,