Skip to content

Commit 7ea7475

Browse files
authored
fix: Improve streaming errors handling (#576)
# Description Refine error management for the streaming operation. Previously, errors were converted into stream parts, resulting in the loss of status info. The updated logic now first verifies if the request was successful; if it failed, a client error is returned, preserving the relevant status information. - [x] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [x] Make your Pull Request title in the <https://www.conventionalcommits.org/> specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [x] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) - [x] Appropriate docs were updated (if necessary) Fixes #502 🦕
1 parent 3bfbea9 commit 7ea7475

File tree

4 files changed

+84
-2
lines changed

4 files changed

+84
-2
lines changed

src/a2a/client/transports/jsonrpc.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,16 @@ async def send_message_streaming(
174174
**modified_kwargs,
175175
) as event_source:
176176
try:
177+
event_source.response.raise_for_status()
177178
async for sse in event_source.aiter_sse():
178179
response = SendStreamingMessageResponse.model_validate(
179180
json.loads(sse.data)
180181
)
181182
if isinstance(response.root, JSONRPCErrorResponse):
182183
raise A2AClientJSONRPCError(response.root)
183184
yield response.root.result
185+
except httpx.HTTPStatusError as e:
186+
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
184187
except SSEError as e:
185188
raise A2AClientHTTPError(
186189
400, f'Invalid SSE response or protocol error: {e}'

src/a2a/client/transports/rest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,13 @@ async def send_message_streaming(
152152
**modified_kwargs,
153153
) as event_source:
154154
try:
155+
event_source.response.raise_for_status()
155156
async for sse in event_source.aiter_sse():
156157
event = a2a_pb2.StreamResponse()
157158
Parse(sse.data, event)
158159
yield proto_utils.FromProto.stream_response(event)
160+
except httpx.HTTPStatusError as e:
161+
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
159162
except SSEError as e:
160163
raise A2AClientHTTPError(
161164
400, f'Invalid SSE response or protocol error: {e}'

tests/client/transports/test_jsonrpc_client.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,44 @@ async def test_send_message_streaming_with_new_extensions(
880880
},
881881
)
882882

883+
@pytest.mark.asyncio
884+
@patch('a2a.client.transports.jsonrpc.aconnect_sse')
885+
async def test_send_message_streaming_server_error_propagates(
886+
self,
887+
mock_aconnect_sse: AsyncMock,
888+
mock_httpx_client: AsyncMock,
889+
mock_agent_card: MagicMock,
890+
):
891+
"""Test that send_message_streaming propagates server errors (e.g., 403, 500) directly."""
892+
client = JsonRpcTransport(
893+
httpx_client=mock_httpx_client,
894+
agent_card=mock_agent_card,
895+
)
896+
params = MessageSendParams(
897+
message=create_text_message_object(content='Error stream')
898+
)
899+
900+
mock_event_source = AsyncMock(spec=EventSource)
901+
mock_response = MagicMock(spec=httpx.Response)
902+
mock_response.status_code = 403
903+
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
904+
'Forbidden',
905+
request=httpx.Request('POST', 'http://test.url'),
906+
response=mock_response,
907+
)
908+
mock_event_source.response = mock_response
909+
mock_event_source.aiter_sse.return_value = async_iterable_from_list([])
910+
mock_aconnect_sse.return_value.__aenter__.return_value = (
911+
mock_event_source
912+
)
913+
914+
with pytest.raises(A2AClientHTTPError) as exc_info:
915+
async for _ in client.send_message_streaming(request=params):
916+
pass
917+
918+
assert exc_info.value.status_code == 403
919+
mock_aconnect_sse.assert_called_once()
920+
883921
@pytest.mark.asyncio
884922
async def test_get_card_no_card_provided_with_extensions(
885923
self, mock_httpx_client: AsyncMock

tests/client/transports/test_rest_client.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,13 @@
77
from httpx_sse import EventSource, ServerSentEvent
88

99
from a2a.client import create_text_message_object
10+
from a2a.client.errors import A2AClientHTTPError
1011
from a2a.client.transports.rest import RestTransport
1112
from a2a.extensions.common import HTTP_EXTENSION_HEADER
1213
from a2a.types import (
1314
AgentCapabilities,
1415
AgentCard,
15-
AgentSkill,
1616
MessageSendParams,
17-
Role,
1817
)
1918

2019

@@ -130,6 +129,45 @@ async def test_send_message_streaming_with_new_extensions(
130129
},
131130
)
132131

132+
@pytest.mark.asyncio
133+
@patch('a2a.client.transports.rest.aconnect_sse')
134+
async def test_send_message_streaming_server_error_propagates(
135+
self,
136+
mock_aconnect_sse: AsyncMock,
137+
mock_httpx_client: AsyncMock,
138+
mock_agent_card: MagicMock,
139+
):
140+
"""Test that send_message_streaming propagates server errors (e.g., 403, 500) directly."""
141+
client = RestTransport(
142+
httpx_client=mock_httpx_client,
143+
agent_card=mock_agent_card,
144+
)
145+
params = MessageSendParams(
146+
message=create_text_message_object(content='Error stream')
147+
)
148+
149+
mock_event_source = AsyncMock(spec=EventSource)
150+
mock_response = MagicMock(spec=httpx.Response)
151+
mock_response.status_code = 403
152+
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
153+
'Forbidden',
154+
request=httpx.Request('POST', 'http://test.url'),
155+
response=mock_response,
156+
)
157+
mock_event_source.response = mock_response
158+
mock_event_source.aiter_sse.return_value = async_iterable_from_list([])
159+
mock_aconnect_sse.return_value.__aenter__.return_value = (
160+
mock_event_source
161+
)
162+
163+
with pytest.raises(A2AClientHTTPError) as exc_info:
164+
async for _ in client.send_message_streaming(request=params):
165+
pass
166+
167+
assert exc_info.value.status_code == 403
168+
169+
mock_aconnect_sse.assert_called_once()
170+
133171
@pytest.mark.asyncio
134172
async def test_get_card_no_card_provided_with_extensions(
135173
self, mock_httpx_client: AsyncMock

0 commit comments

Comments
 (0)