diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py index 5dc5bca3a..07ac9c7d0 100644 --- a/libs/vertexai/langchain_google_vertexai/chat_models.py +++ b/libs/vertexai/langchain_google_vertexai/chat_models.py @@ -443,7 +443,7 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]: vertex_messages.append(Content(role=role, parts=parts)) elif isinstance(message, FunctionMessage): prev_ai_message = None - role = "function" + role = "user" part = Part( function_response=FunctionResponse( @@ -453,9 +453,7 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]: parts = [part] if vertex_messages: prev_content = vertex_messages[-1] - prev_content_is_function = ( - prev_content and prev_content.role == "function" - ) + prev_content_is_function = prev_content and prev_content.role == "user" if prev_content_is_function: prev_parts = list(prev_content.parts) prev_parts.extend(parts) @@ -465,7 +463,7 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]: vertex_messages.append(Content(role=role, parts=parts)) elif isinstance(message, ToolMessage): - role = "function" + role = "user" # message.name can be null for ToolMessage name = message.name @@ -527,9 +525,9 @@ def _parse_content(raw_content: str | Dict[Any, Any]) -> Dict[Any, Any]: parts = [part] prev_content = vertex_messages[-1] - prev_content_is_function = prev_content and prev_content.role == "function" + prev_content_is_tool_response = prev_content and prev_content.role == "user" - if prev_content_is_function: + if prev_content_is_tool_response: prev_parts = list(prev_content.parts) prev_parts.extend(parts) # replacing last message diff --git a/libs/vertexai/tests/unit_tests/test_chat_models.py b/libs/vertexai/tests/unit_tests/test_chat_models.py index d6d92ec02..8487efe42 100644 --- a/libs/vertexai/tests/unit_tests/test_chat_models.py +++ b/libs/vertexai/tests/unit_tests/test_chat_models.py @@ -333,7 +333,7 @@ def test_parse_history_gemini_function_empty_list() -> None: name=tool_call_1["name"], args=tool_call_1["args"] ) - assert history[2].role == "function" + assert history[2].role == "user" assert history[2].parts[0].function_response == FunctionResponse( name=fn_name_1, response={"content": ""}, @@ -414,7 +414,7 @@ def test_parse_history_gemini_function() -> None: name=tool_call_2["name"], args=tool_call_2["args"] ) - assert history[2].role == "function" + assert history[2].role == "user" assert history[2].parts[0].function_response == FunctionResponse( name=fn_name_1, response={"content": message3.content}, @@ -429,7 +429,7 @@ def test_parse_history_gemini_function() -> None: name=tool_call_3["name"], args=tool_call_3["args"] ) - assert history[4].role == "function" + assert history[4].role == "user" assert history[4].parts[0].function_response == FunctionResponse( name=fn_name_3, response={"content": message6.content}, @@ -1246,7 +1246,7 @@ def test_multiple_fc() -> None: ) ), ], - role="function", + role="user", ), ] assert history == expected @@ -1549,7 +1549,7 @@ def test_thought_signature() -> None: ], ), Content( - role="function", + role="user", parts=[ Part( function_response=FunctionResponse(