Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
vertex_messages.append(Content(role=role, parts=parts))
elif isinstance(message, FunctionMessage):
prev_ai_message = None
role = "function"
role = "user"

part = Part(
function_response=FunctionResponse(
Expand All @@ -453,9 +453,7 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
parts = [part]
if vertex_messages:
prev_content = vertex_messages[-1]
prev_content_is_function = (
prev_content and prev_content.role == "function"
)
prev_content_is_function = prev_content and prev_content.role == "user"
if prev_content_is_function:
prev_parts = list(prev_content.parts)
prev_parts.extend(parts)
Expand All @@ -465,7 +463,7 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:

vertex_messages.append(Content(role=role, parts=parts))
elif isinstance(message, ToolMessage):
role = "function"
role = "user"

# message.name can be null for ToolMessage
name = message.name
Expand Down Expand Up @@ -527,9 +525,9 @@ def _parse_content(raw_content: str | Dict[Any, Any]) -> Dict[Any, Any]:
parts = [part]

prev_content = vertex_messages[-1]
prev_content_is_function = prev_content and prev_content.role == "function"
prev_content_is_tool_response = prev_content and prev_content.role == "user"

if prev_content_is_function:
if prev_content_is_tool_response:
prev_parts = list(prev_content.parts)
prev_parts.extend(parts)
# replacing last message
Expand Down
10 changes: 5 additions & 5 deletions libs/vertexai/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def test_parse_history_gemini_function_empty_list() -> None:
name=tool_call_1["name"], args=tool_call_1["args"]
)

assert history[2].role == "function"
assert history[2].role == "user"
assert history[2].parts[0].function_response == FunctionResponse(
name=fn_name_1,
response={"content": ""},
Expand Down Expand Up @@ -414,7 +414,7 @@ def test_parse_history_gemini_function() -> None:
name=tool_call_2["name"], args=tool_call_2["args"]
)

assert history[2].role == "function"
assert history[2].role == "user"
assert history[2].parts[0].function_response == FunctionResponse(
name=fn_name_1,
response={"content": message3.content},
Expand All @@ -429,7 +429,7 @@ def test_parse_history_gemini_function() -> None:
name=tool_call_3["name"], args=tool_call_3["args"]
)

assert history[4].role == "function"
assert history[4].role == "user"
assert history[4].parts[0].function_response == FunctionResponse(
name=fn_name_3,
response={"content": message6.content},
Expand Down Expand Up @@ -1246,7 +1246,7 @@ def test_multiple_fc() -> None:
)
),
],
role="function",
role="user",
),
]
assert history == expected
Expand Down Expand Up @@ -1549,7 +1549,7 @@ def test_thought_signature() -> None:
],
),
Content(
role="function",
role="user",
parts=[
Part(
function_response=FunctionResponse(
Expand Down