diff --git a/dspy/streaming/streaming_listener.py b/dspy/streaming/streaming_listener.py
index eb0cc18f91..626b104821 100644
--- a/dspy/streaming/streaming_listener.py
+++ b/dspy/streaming/streaming_listener.py
@@ -3,6 +3,7 @@
 from queue import Queue
 from typing import TYPE_CHECKING, Any
 
+import jiter
 from litellm import ModelResponseStream
 
 from dspy.adapters.chat_adapter import ChatAdapter
@@ -49,6 +50,8 @@ def __init__(
         self.cache_hit = False
         self.allow_reuse = allow_reuse
 
+        self.json_adapter_state = {"field_accumulated_messages": ""}
+
         self.adapter_identifiers = {
             "ChatAdapter": {
                 "start_identifier": f"[[ ## {self.signature_field_name} ## ]]",
@@ -62,7 +65,7 @@ def __init__(
                 "end_identifier": re.compile(r"\w*\"(,|\s*})"),
                 "start_indicator": '"',
                 "end_pattern_prefixes": ['"', '",', '" ', '"}'],
-                "end_pattern_contains": None,
+                "end_pattern_contains": "}",
             },
             "XMLAdapter": {
                 "start_identifier": f"<{self.signature_field_name}>",
@@ -126,6 +129,7 @@ def receive(self, chunk: ModelResponseStream):
                 self.cache_hit = False
                 self.field_start_queue = []
                 self.field_end_queue = Queue()
+                self.json_adapter_state["field_accumulated_messages"] = ""
                 self.stream_start = False
             else:
                 return
@@ -147,7 +151,7 @@ def receive(self, chunk: ModelResponseStream):
                     is_last_chunk=self.stream_end,
                 )
 
-        if chunk_message and start_identifier in chunk_message:
+        if chunk_message and start_identifier in chunk_message and not isinstance(settings.adapter, JSONAdapter):
             # If the cache is hit, the chunk_message could be the full response. When it happens we can
             # directly end the stream listening. In some models like gemini, each stream chunk can be multiple
             # tokens, so it's possible that response only has one chunk, we also fall back to this logic.
@@ -180,10 +184,13 @@ def receive(self, chunk: ModelResponseStream):
             # Keep the part after the start_identifier from the concat_message, we need to write it to the buffer.
             value_start_index = concat_message.find(start_identifier) + len(start_identifier)
             chunk_message = concat_message[value_start_index:].lstrip()
-            if isinstance(settings.adapter, JSONAdapter) and chunk_message.startswith('"'):
-                # For JSONAdapter, we need to remove the leading ". We cannot do this with the start_identifier
-                # because there could be a few splitters between ':' and '"', e.g., '"name": "value"'.
-                chunk_message = chunk_message[1:]
+
+            if isinstance(settings.adapter, JSONAdapter):
+                # For JSONAdapter, we rely on partial json parsing to detect the end of the field we are listening
+                # to, so we need to maintain some extra state to help us with that.
+                # We prepend an extra "{" to field_accumulated_messages so we can detect the appearance of the
+                # next key.
+                self.json_adapter_state["field_accumulated_messages"] += "{" + start_identifier
 
         elif self._buffered_message_end_with_start_identifier(concat_message.strip(), start_identifier):
            # If the buffered message ends with part of the start_identifier, we keep looking for the
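Before the next hunk, it may help to see the accumulation trick in isolation. The sketch below is standalone and uses an invented token stream (it is not part of the diff): prepending `"{"` plus the matched start identifier makes the buffered fragment itself a partial JSON object, so jiter's partial mode can report which keys have appeared so far.

```python
# Standalone sketch (invented token stream, not part of the diff). The synthetic
# opening brace turns the buffered fragment into a parseable partial JSON object.
import jiter

start_identifier = '"answer":'          # what the listener matched in the raw stream
accumulated = "{" + start_identifier    # the synthetic opening brace

for token in ['"To', " get", " to", " the", " other", ' side!"', ', "judgement":', ' "The']:
    accumulated += token
    parsed = jiter.from_json(accumulated.encode("utf-8"), partial_mode="trailing-strings")
    if len(parsed) > 1:
        # A second key surfaced: the field we are listening to is complete.
        print("field finished; next key:", next(k for k in parsed if k != "answer"))
        break
```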
@@ -196,30 +203,98 @@ def receive(self, chunk: ModelResponseStream):
 
         if self.stream_start and chunk_message:
             # The stream is started, we keep returning the token until we see the start of the next field.
-            token = None
             self.field_end_queue.put(chunk_message)
+            token = None
             concat_message = "".join(self.field_end_queue.queue).strip()
-            if re.search(end_identifier, concat_message):
-                # The next field is identified, we can end the stream and flush out all tokens in the buffer.
-                self.stream_end = True
-                token = self.flush()
-                token = token.rstrip()  # Remove the trailing \n\n
-            elif not self._could_form_end_identifier(concat_message, adapter_name):
+
+            if not self._could_form_end_identifier(concat_message, adapter_name):
                 # Buffer cannot form end identifier, safe to flush out the tokens in the buffer.
                 token = self.flush()
             elif self.field_end_queue.qsize() > 10:
-                # Buffer could form end identifier, but we've exceeded max buffer size
-                # Yield the oldest token to prevent unbounded buffering
+                # We keep the last 10 tokens in the buffer if they can potentially form the end_identifier, to avoid
+                # sending the DSPy boilerplate tokens to users. 10 is a heuristic number that is sufficient to
+                # capture the end_identifier for all LMs.
                 token = self.field_end_queue.get()
 
-            if token:
+            if isinstance(settings.adapter, JSONAdapter):
+                # JSONAdapter uses partial json parsing to detect the end of the field we are listening to, instead
+                # of relying on the end_identifier.
+                return self._json_adapter_handle_stream_chunk(token, chunk_message)
+            else:
+                # Other adapters rely on the end_identifier to detect the end of the field we are listening to.
+                return self._default_handle_stream_chunk(token, end_identifier)
+
+    def _json_adapter_handle_stream_chunk(self, token: str, chunk_message: str) -> StreamResponse | None:
+        self.json_adapter_state["field_accumulated_messages"] += chunk_message
+        if self.json_adapter_state["field_accumulated_messages"].rstrip().endswith("}"):
+            # When the accumulated tokens end with a curly bracket, the streaming for the `dspy.Predict` we are
+            # listening to is probably finished, so we run a check and decide whether to end the stream.
+            try:
+                # If the parse doesn't raise an error, the accumulated tokens form a valid json object. Because we
+                # prepend an extra "{" to the beginning of field_accumulated_messages, a successful parse means the
+                # streaming is finished.
+                jiter.from_json(self.json_adapter_state["field_accumulated_messages"].encode("utf-8"))
+                self.stream_end = True
+                last_token = self.flush()
+                right_curly_bracket_index = last_token.rfind("}")
+                token = (
+                    token + last_token[:right_curly_bracket_index] if token else last_token[:right_curly_bracket_index]
+                )
                 return StreamResponse(
-                    self.predict_name,
-                    self.signature_field_name,
-                    token,
-                    is_last_chunk=self.stream_end,
+                    self.predict_name, self.signature_field_name, token, is_last_chunk=self.stream_end
                 )
+            except ValueError:
+                pass
+
+        try:
+            parsed = jiter.from_json(
+                self.json_adapter_state["field_accumulated_messages"].encode("utf-8"),
+                partial_mode="trailing-strings",
+            )
+            if len(parsed) > 1:
+                # If partial json parsing finds a second key, the streaming for the field we are listening to
+                # is finished.
+                self.stream_end = True
+                last_token = self.flush()
+
+                keys = list(parsed.keys())
+                next_field_name = None
+                for key in keys:
+                    if key != self.signature_field_name:
+                        next_field_name = key
+                        break
+
+                last_token_index = last_token.find(next_field_name)
+                token = token + last_token[:last_token_index] if token else last_token[:last_token_index]
+        except ValueError:
+            pass
+
+        if token:
+            return StreamResponse(
+                self.predict_name,
+                self.signature_field_name,
+                token,
+                is_last_chunk=self.stream_end,
+            )
+
+    def _default_handle_stream_chunk(self, token: str, end_identifier: str) -> StreamResponse | None:
+        concat_message = "".join(self.field_end_queue.queue).strip()
+
+        if re.search(end_identifier, concat_message):
+            # The next field is identified, so we can end the stream and flush out all tokens in the buffer.
+            self.stream_end = True
+            last_token = self.flush()
+            token = token + last_token if token else last_token
+            token = token.rstrip()  # Remove the trailing \n\n
+
+        if token:
+            return StreamResponse(
+                self.predict_name,
+                self.signature_field_name,
+                token,
+                is_last_chunk=self.stream_end,
+            )
 
     def flush(self) -> str:
         """Flush all tokens in the field end queue.
@@ -231,12 +306,7 @@ def flush(self) -> str:
         last_tokens = "".join(self.field_end_queue.queue)
         self.field_end_queue = Queue()
         if isinstance(settings.adapter, JSONAdapter):
-            match = re.search(r'",|"\s*}', last_tokens)
-            if match:
-                boundary_index = match.start()
-            else:
-                boundary_index = len(last_tokens)
-            return last_tokens[:boundary_index]
+            return last_tokens
         elif isinstance(settings.adapter, XMLAdapter):
             boundary_index = last_tokens.find(f"</{self.signature_field_name}>")
             if boundary_index == -1:
@@ -314,13 +384,6 @@ def find_predictor_for_stream_listeners(
                 f"Signature field {field_name} is not unique in the program, cannot automatically determine which "
                 "predictor to use for streaming. Please specify the predictor to listen to."
             )
-
-        if not _is_streamable(field_info.annotation):
-            raise ValueError(
-                f"Stream listener can only be applied to string or subclass of `dspy.Type` that has `is_streamable() == True`, "
-                f"but your field {field_name} is of type {field_info.annotation}."
-            )
-
         field_name_to_named_predictor[field_name] = (name, predictor)
 
     predict_id_to_listener = defaultdict(list)
@@ -337,13 +400,3 @@ def find_predictor_for_stream_listeners(
         listener.predict_name, listener.predict = field_name_to_named_predictor[listener.signature_field_name]
         predict_id_to_listener[id(listener.predict)].append(listener)
     return predict_id_to_listener
-
-
-def _is_streamable(field_type: type | None) -> bool:
-    if field_type is None:
-        return False
-    if field_type is str:
-        return True
-    if issubclass(field_type, Type):
-        return field_type.is_streamable()
-    return False
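Between the implementation and the test changes, here is a condensed standalone sketch (illustrative tokens, trimming elided; not part of the diff) of the two checks `_json_adapter_handle_stream_chunk` runs on each buffered chunk: (1) if the accumulated text ends with `"}"` and parses as complete JSON, the whole prediction is finished; (2) otherwise, if a partial parse surfaces a second key, the listened field is finished.

```python
# Condensed sketch of the two end-of-stream checks, under the same synthetic-"{"
# convention the listener uses. Tokens are invented for illustration.
import jiter

accumulated = '{"answer":'  # synthetic "{" + start identifier, as in the listener

for chunk in ['"The answer', ' is 42."', "}"]:
    accumulated += chunk
    if accumulated.rstrip().endswith("}"):
        try:
            jiter.from_json(accumulated.encode("utf-8"))
            print("prediction finished")  # check 1: the accumulated text is complete JSON
            break
        except ValueError:
            pass  # a nested object closed, but the top-level object is still open
    try:
        parsed = jiter.from_json(accumulated.encode("utf-8"), partial_mode="trailing-strings")
        if len(parsed) > 1:
            print("field finished")  # check 2: a second key means the next field started
            break
    except ValueError:
        pass
```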
diff --git a/tests/streaming/test_streaming.py b/tests/streaming/test_streaming.py
index 89a84250ff..a1e75be5dd 100644
--- a/tests/streaming/test_streaming.py
+++ b/tests/streaming/test_streaming.py
@@ -4,6 +4,7 @@
 from unittest import mock
 from unittest.mock import AsyncMock
 
+import pydantic
 import pytest
 from asyncer import syncify
 from litellm.types.utils import Delta, ModelResponseStream, StreamingChoices
@@ -11,7 +12,7 @@
 import dspy
 from dspy.adapters.types import Type
 from dspy.experimental import Citations, Document
-from dspy.streaming import StatusMessage, StatusMessageProvider, streaming_response
+from dspy.streaming import StatusMessage, StatusMessageProvider, StreamResponse, streaming_response
 
 
 @pytest.mark.anyio
@@ -455,7 +456,7 @@ async def gpt_4o_mini_stream_1(*args, **kwargs):
         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='{"'))])
         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="answer"))])
         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='":'))])
-        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="To"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='"To'))])
         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" get"))])
         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" to"))])
         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" the"))])
@@ -475,8 +476,8 @@ async def gpt_4o_mini_stream_2(*args, **kwargs):
         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='{"'))])
         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="jud"))])
         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="gement"))])
-        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='":"'))])
-        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="The"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='":'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='"The'))])
         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" answer"))])
         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" is"))])
         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" humorous"))])
@@ -515,7 +516,7 @@ async def gpt_4o_mini_stream_2(*args, **kwargs):
 
     assert all_chunks[0].predict_name == "predict1"
     assert all_chunks[0].signature_field_name == "answer"
-    assert all_chunks[0].chunk == "To"
+    assert all_chunks[0].chunk == '"To'
     assert all_chunks[1].chunk == " get"
     assert all_chunks[2].chunk == " to"
    assert all_chunks[3].chunk == " the"
@@ -525,12 +526,12 @@ async def gpt_4o_mini_stream_2(*args, **kwargs):
     assert all_chunks[7].chunk == " the"
     assert all_chunks[8].chunk == " frying"
     assert all_chunks[9].chunk == " pan"
-    assert all_chunks[10].chunk == "!"
+    assert all_chunks[10].chunk == '!"'
     assert all_chunks[10].is_last_chunk is True
 
     assert all_chunks[11].predict_name == "predict2"
     assert all_chunks[11].signature_field_name == "judgement"
-    assert all_chunks[11].chunk == "The"
+    assert all_chunks[11].chunk == '"The'
     assert all_chunks[12].chunk == " answer"
     assert all_chunks[13].chunk == " is"
     assert all_chunks[14].chunk == " humorous"
@@ -544,7 +545,7 @@ async def gpt_4o_mini_stream_2(*args, **kwargs):
     assert all_chunks[22].chunk == " classic"
     assert all_chunks[23].chunk == " joke"
     assert all_chunks[24].chunk == " format"
-    assert all_chunks[25].chunk == "."
+    assert all_chunks[25].chunk == '."'
     assert all_chunks[25].is_last_chunk is True
 
 
@@ -743,11 +744,14 @@ async def gemini_stream_2(*args, **kwargs):
 
     assert all_chunks[0].predict_name == "predict1"
     assert all_chunks[0].signature_field_name == "answer"
-    assert all_chunks[0].chunk == "To get to the other side... of the cutting board!"
+
+    assert all_chunks[0].chunk == '"To get to the other side... of the cutting board!"'
 
     assert all_chunks[1].predict_name == "predict2"
     assert all_chunks[1].signature_field_name == "judgement"
-    assert all_chunks[1].chunk == "The answer provides a humorous and relevant punchline to the classic joke setup."
+    assert (
+        all_chunks[1].chunk == '"The answer provides a humorous and relevant punchline to the classic joke setup."'
+    )
 
 
 @pytest.mark.anyio
@@ -1121,7 +1125,349 @@ async def citation_stream(*args, **kwargs):
     assert final_prediction is not None
     assert hasattr(final_prediction, "answer")
     assert hasattr(final_prediction, "citations")
-    assert final_prediction.answer == "According to the references, water boils at 100°C."
+
+
+# Test Pydantic Models
+class SimpleResponse(pydantic.BaseModel):
+    message: str
+    status: str
+
+
+class NestedResponse(pydantic.BaseModel):
+    title: str
+    content: dict
+    metadata: SimpleResponse
+
+
+class ComplexResponse(pydantic.BaseModel):
+    items: list[str]
+    settings: dict[str, str]
+    active: bool
+
+
+@pytest.mark.anyio
+async def test_chat_adapter_simple_pydantic_streaming():
+    """Test ChatAdapter streaming with a simple pydantic model."""
+
+    class TestSignature(dspy.Signature):
+        question: str = dspy.InputField()
+        response: SimpleResponse = dspy.OutputField()
+
+    class MyProgram(dspy.Module):
+        def __init__(self):
+            self.predict = dspy.Predict(TestSignature)
+
+        def forward(self, question, **kwargs):
+            return self.predict(question=question, **kwargs)
+
+    async def chat_stream(*args, **kwargs):
+        # Simulate streaming of a pydantic model via ChatAdapter format
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="[[ ##"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" response"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ## ]]\n\n"))])
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='{"message": "Hello'))]
+        )
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=' world!"'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=', "status":'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=' "success"}'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="\n\n[[ ##"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" completed"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ## ]]"))])
+
+    program = dspy.streamify(
+        MyProgram(),
+        stream_listeners=[
+            dspy.streaming.StreamListener(signature_field_name="response"),
+        ],
+    )
+
+    with mock.patch("litellm.acompletion", side_effect=chat_stream):
+        with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.ChatAdapter()):
+            output = program(question="Say hello")
+            chunks = []
+            async for value in output:
+                if isinstance(value, StreamResponse):
+                    chunks.append(value)
+
+    # Verify we got chunks for the pydantic field
+    assert len(chunks) > 0
+    assert chunks[0].signature_field_name == "response"
+
+    # Combine all chunks to verify the content
+    full_content = "".join(chunk.chunk for chunk in chunks)
+    assert "Hello world!" in full_content
+    assert "success" in full_content
+
+
+@pytest.mark.anyio
+async def test_chat_adapter_nested_pydantic_streaming():
+    """Test ChatAdapter streaming with a nested pydantic model."""
+
+    class TestSignature(dspy.Signature):
+        question: str = dspy.InputField()
+        response: NestedResponse = dspy.OutputField()
+
+    async def nested_stream(*args, **kwargs):
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="[[ ## response ## ]]\n\n"))]
+        )
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='{"title": "Test"'))]
+        )
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=', "content": {"key": "value"}'))]
+        )
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=', "metadata": {"message": "nested"'))]
+        )
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=', "status": "ok"}}'))]
+        )
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="\n\n[[ ## completed ## ]]"))]
+        )
+
+    program = dspy.streamify(
+        dspy.Predict(TestSignature),
+        stream_listeners=[
+            dspy.streaming.StreamListener(signature_field_name="response"),
+        ],
+    )
+
+    with mock.patch("litellm.acompletion", side_effect=nested_stream):
+        with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.ChatAdapter()):
+            output = program(question="Generate nested response")
+            chunks = []
+            async for value in output:
+                if isinstance(value, StreamResponse):
+                    chunks.append(value)
+
+    assert len(chunks) > 0
+    full_content = "".join(chunk.chunk for chunk in chunks)
+    assert "nested" in full_content
+    assert "Test" in full_content
+
+
+@pytest.mark.anyio
+async def test_chat_adapter_mixed_fields_streaming():
+    """Test ChatAdapter streaming with both pydantic and string fields."""
+
+    class TestSignature(dspy.Signature):
+        question: str = dspy.InputField()
+        summary: str = dspy.OutputField()
+        details: SimpleResponse = dspy.OutputField()
+
+    async def mixed_stream(*args, **kwargs):
+        # First output field (summary - string)
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="[[ ## summary ## ]]\n\n"))]
+        )
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="This is a summary"))]
+        )
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" of the response"))]
+        )
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="\n\n[[ ## details ## ]]\n\n"))]
+        )
+        # Second output field (details - pydantic)
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='{"message": "Detailed info"'))]
+        )
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=', "status": "complete"}'))]
+        )
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="\n\n[[ ## completed ## ]]"))]
+        )
+
+    program = dspy.streamify(
+        dspy.Predict(TestSignature),
+        stream_listeners=[
+            dspy.streaming.StreamListener(signature_field_name="summary"),
+            dspy.streaming.StreamListener(signature_field_name="details"),
+        ],
+    )
+
+    with mock.patch("litellm.acompletion", side_effect=mixed_stream):
+        with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.ChatAdapter()):
+            output = program(question="Generate mixed response")
+            summary_chunks = []
+            details_chunks = []
+            async for value in output:
+                if isinstance(value, StreamResponse):
+                    if value.signature_field_name == "summary":
+                        summary_chunks.append(value)
+                    elif value.signature_field_name == "details":
+                        details_chunks.append(value)
+
+    # Verify both field types were streamed
+    assert len(summary_chunks) > 0
+    assert len(details_chunks) > 0
+
+    summary_content = "".join(chunk.chunk for chunk in summary_chunks)
+    details_content = "".join(chunk.chunk for chunk in details_chunks)
+
+    assert "summary" in summary_content
+    assert "Detailed info" in details_content
+
+
+@pytest.mark.anyio
+async def test_json_adapter_simple_pydantic_streaming():
+    """Test JSONAdapter streaming with a simple pydantic model."""
+
+    class TestSignature(dspy.Signature):
+        question: str = dspy.InputField()
+        response: SimpleResponse = dspy.OutputField()
+
+    async def json_stream(*args, **kwargs):
+        # Simulate a pydantic output field streamed as raw JSON, token by token
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='{"'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='response"'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=":"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='{"message"'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=': "Hello'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=' JSON!"'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=', "status"'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=': "ok"}'))])
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="}"))]
+        )  # Close main object
+
+    program = dspy.streamify(
+        dspy.Predict(TestSignature),
+        stream_listeners=[
+            dspy.streaming.StreamListener(signature_field_name="response"),
+        ],
+    )
+
+    with mock.patch("litellm.acompletion", side_effect=json_stream):
+        with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter()):
+            output = program(question="Say hello in JSON")
+            chunks = []
+            async for value in output:
+                if isinstance(value, StreamResponse):
+                    chunks.append(value)
+
+    assert len(chunks) > 0
+    assert chunks[0].signature_field_name == "response"
+
+    full_content = "".join(chunk.chunk for chunk in chunks)
+    assert "Hello JSON!" in full_content
+
+
+@pytest.mark.anyio
+async def test_json_adapter_bracket_balance_detection():
+    """Test JSONAdapter correctly detects field completion for nested objects and arrays."""
+
+    class TestSignature(dspy.Signature):
+        question: str = dspy.InputField()
+        response: ComplexResponse = dspy.OutputField()
+
+    async def complex_json_stream(*args, **kwargs):
+        # Nest objects and arrays to exercise the end-of-field detection
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='{"'))])
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='response": {'))]
+        )  # +1 bracket
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='"items": ["a"'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=', "b"], '))])
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='"settings": {"key"'))]
+        )  # +1 bracket
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=': "value"}, '))]
+        )  # -1 bracket
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='"active": true}'))]
+        )  # -1 bracket (should end field)
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="}"))]
+        )  # Close main object
+
+    program = dspy.streamify(
+        dspy.Predict(TestSignature),
+        stream_listeners=[
+            dspy.streaming.StreamListener(signature_field_name="response"),
+        ],
+    )
+
+    with mock.patch("litellm.acompletion", side_effect=complex_json_stream):
+        with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter()):
+            output = program(question="Generate complex JSON")
+            chunks = []
+            async for value in output:
+                if isinstance(value, StreamResponse):
+                    chunks.append(value)
+
+    assert len(chunks) > 0
+    # Check that the last chunk is marked as the last
+    assert chunks[-1].is_last_chunk is True
+
+    full_content = "".join(chunk.chunk for chunk in chunks)
+
+    assert "items" in full_content
+    assert "settings" in full_content
+
+
+@pytest.mark.anyio
+async def test_json_adapter_multiple_fields_detection():
+    """Test JSONAdapter correctly detects when the next field starts."""
+
+    class TestSignature(dspy.Signature):
+        question: str = dspy.InputField()
+        first: SimpleResponse = dspy.OutputField()
+        second: SimpleResponse = dspy.OutputField()
+
+    async def multi_field_stream(*args, **kwargs):
+        # First field
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='{"first": {'))])
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='"message": "first response"'))]
+        )
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=', "status": "ok"}'))]
+        )
+        # Second field starts
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=', "second": {'))])
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='"message": "second response"'))]
+        )
+        yield ModelResponseStream(
+            model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=', "status": "done"}}'))]
+        )
+
+    program = dspy.streamify(
+        dspy.Predict(TestSignature),
+        stream_listeners=[
+            dspy.streaming.StreamListener(signature_field_name="first"),
+            dspy.streaming.StreamListener(signature_field_name="second"),
+        ],
+    )
+
+    with mock.patch("litellm.acompletion", side_effect=multi_field_stream):
+        with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter()):
+            output = program(question="Generate two responses")
+            first_chunks = []
+            second_chunks = []
+            async for value in output:
+                if isinstance(value, StreamResponse):
+                    if value.signature_field_name == "first":
+                        first_chunks.append(value)
+                    elif value.signature_field_name == "second":
+                        second_chunks.append(value)
+
+    # Verify both fields were detected and streamed
+    assert len(first_chunks) > 0
+    assert len(second_chunks) > 0
+
+    first_content = "".join(chunk.chunk for chunk in first_chunks)
+    second_content = "".join(chunk.chunk for chunk in second_chunks)
+
+    assert "first response" in first_content
+    assert "second response" in second_content
 
 
 def test_stream_listener_could_form_end_identifier_chat_adapter():
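For reviewers who want to exercise the new behavior end to end, here is a hypothetical usage sketch in the spirit of the tests above. The model name, signature, and question are placeholders, and it assumes a configured OpenAI API key; it is not part of the diff.

```python
# Hypothetical end-to-end usage (placeholder model/signature; assumes an OpenAI
# key is configured). With this PR, a pydantic output field streams as raw JSON
# text chunks under JSONAdapter instead of being rejected by the stream listener.
import asyncio

import pydantic

import dspy
from dspy.streaming import StreamResponse


class SimpleResponse(pydantic.BaseModel):
    message: str
    status: str


class QA(dspy.Signature):
    question: str = dspy.InputField()
    response: SimpleResponse = dspy.OutputField()


async def main():
    dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"), adapter=dspy.JSONAdapter())
    program = dspy.streamify(
        dspy.Predict(QA),
        stream_listeners=[dspy.streaming.StreamListener(signature_field_name="response")],
    )
    async for value in program(question="Say hello"):
        if isinstance(value, StreamResponse):
            print(value.chunk, end="", flush=True)  # raw JSON text of the field


asyncio.run(main())
```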