
Commit a0b2155

fix: add fallback mechanism on missing end marker during streaming (#8890)
* fix: add fallback mechanism on missing end marker during streaming
* address PR review comments
* fix tests

---------

Co-authored-by: chenmoneygithub <[email protected]>
1 parent 965da1d commit a0b2155

File tree: 3 files changed, +136 / -24 lines changed


dspy/streaming/streamify.py

Lines changed: 5 additions & 0 deletions
@@ -193,6 +193,11 @@ async def async_streamer(*args, **kwargs):
         elif isinstance(value, StatusMessage):
             yield value
         elif isinstance(value, Prediction):
+            # Flush remaining buffered tokens before yielding the Prediction instance
+            for listener in stream_listeners:
+                if final_chunk := listener.finalize():
+                    yield final_chunk
+
             if include_final_prediction_in_output_stream:
                 yield value
         elif (
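The added loop is the caller-visible half of the fix: anything a listener is still holding is emitted as one last StreamResponse immediately before the Prediction. A minimal consumer sketch of that behavior (the model name, question, and one-field signature are illustrative, not taken from this commit):

import asyncio

import dspy


async def main():
    # Stream a single "answer" field; the program and LM below are placeholders.
    program = dspy.streamify(
        dspy.Predict("question->answer"),
        stream_listeners=[dspy.streaming.StreamListener(signature_field_name="answer")],
    )
    with dspy.context(lm=dspy.LM("openai/gpt-4o-mini"), adapter=dspy.ChatAdapter()):
        async for value in program(question="What temperature does water boil at?"):
            if isinstance(value, dspy.streaming.StreamResponse):
                # With this fix, the buffered tail still arrives here as a final chunk
                # even when the `[[ ## completed ## ]]` marker never shows up.
                print(value.chunk, end="", flush=True)
            elif isinstance(value, dspy.Prediction):
                print("\n---", value.answer)


asyncio.run(main())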

dspy/streaming/streaming_listener.py

Lines changed: 29 additions & 1 deletion
@@ -208,13 +208,41 @@ def flush(self) -> str:
             return last_tokens[:boundary_index]
         elif isinstance(settings.adapter, ChatAdapter) or settings.adapter is None:
             boundary_index = last_tokens.find("[[")
+            if boundary_index == -1:
+                boundary_index = len(last_tokens)
             return last_tokens[:boundary_index]
         else:
             raise ValueError(
                 f"Unsupported adapter for streaming: {settings.adapter}, please use one of the following adapters: "
                 f"{', '.join([a.__name__ for a in ADAPTER_SUPPORT_STREAMING])}"
             )
 
+    def finalize(self) -> StreamResponse | None:
+        """Finalize the stream and flush any remaining buffered tokens.
+
+        This should be called when the stream ends.
+        It ensures no tokens are lost from the buffer and marks the final chunk appropriately.
+
+        Returns:
+            A StreamResponse with the remaining buffered tokens and is_last_chunk=True,
+            or None if there are no buffered tokens or the stream hasn't started.
+        """
+        if self.stream_end or not self.stream_start:
+            # Stream already ended or never started, nothing to finalize
+            return None
+
+        self.stream_end = True
+        if self.field_end_queue.qsize() > 0:
+            token = self.flush()
+            if token:
+                return StreamResponse(
+                    self.predict_name,
+                    self.signature_field_name,
+                    token,
+                    is_last_chunk=True,
+                )
+        return None
+
     @property
     def _output_type(self) -> type | None:
         try:
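For context on why finalize() is needed at all: while a field is streaming, the listener keeps the most recent tokens in field_end_queue so the end marker can be trimmed off before the text is surfaced; if the marker never arrives, that held-back tail would otherwise be lost. A stripped-down model of the idea (a toy, not the dspy implementation):

from collections import deque


class ToyListener:
    """Toy buffering model: hold back the last few tokens so a trailing
    marker such as "[[ ## completed ## ]]" can be trimmed before emitting."""

    def __init__(self, hold_back: int = 10):
        self.hold_back = hold_back
        self.buffer: deque[str] = deque()
        self.emitted: list[str] = []

    def receive(self, token: str) -> None:
        self.buffer.append(token)
        # Release a token only once it can no longer be part of the end marker.
        while len(self.buffer) > self.hold_back:
            self.emitted.append(self.buffer.popleft())

    def finalize(self) -> str:
        # Fallback when the stream ends without a marker: emit the held-back
        # tail, trimmed at a partial "[[" if one is present (mirrors flush()).
        tail = "".join(self.buffer)
        boundary = tail.find("[[")
        if boundary == -1:
            boundary = len(tail)
        return tail[:boundary]


listener = ToyListener(hold_back=3)
for token in ["This", " is", " a", " test", "."]:
    listener.receive(token)
print("".join(listener.emitted) + listener.finalize())  # -> "This is a test."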
@@ -224,7 +252,7 @@ def _output_type(self) -> type | None:
 
 
 
-def find_predictor_for_stream_listeners(program: "Module", stream_listeners: list[StreamListener]):
+def find_predictor_for_stream_listeners(program: "Module", stream_listeners: list[StreamListener]) -> dict[int, list[StreamListener]]:
     """Find the predictor for each stream listener.
 
     This is a utility function to automatically find the predictor for each stream listener. It is used when some
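The new annotation only documents the return shape: listeners grouped per predictor, keyed by an int. A sketch of how such a mapping could be built, assuming id()-based keys and matching by output field name (both are assumptions; neither is shown in this diff):

from collections import defaultdict


def group_listeners_by_predictor(named_predictors, stream_listeners):
    # named_predictors: (name, predictor) pairs, e.g. from Module.named_predictors().
    mapping: dict[int, list] = defaultdict(list)
    for _name, predictor in named_predictors:
        for listener in stream_listeners:
            # Hypothetical matching rule, for illustration only: the listener targets
            # a field that this predictor's signature produces.
            if listener.signature_field_name in predictor.signature.output_fields:
                mapping[id(predictor)].append(listener)
    return dict(mapping)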

tests/streaming/test_streaming.py

Lines changed: 102 additions & 23 deletions
@@ -166,9 +166,10 @@ def __call__(self, x: str, **kwargs):
 
     assert all_chunks[0].predict_name == "predict1"
     assert all_chunks[0].signature_field_name == "answer"
-
-    assert all_chunks[-1].predict_name == "predict2"
-    assert all_chunks[-1].signature_field_name == "judgement"
+    # The last chunk can be from either predictor because sometimes small LMs miss the `[[ ## completed ## ]]` marker,
+    # which results in an extra chunk that flushes out the buffer.
+    assert all_chunks[-2].predict_name == "predict2"
+    assert all_chunks[-2].signature_field_name == "judgement"
 
 
 @pytest.mark.anyio
@@ -299,10 +300,11 @@ def __call__(self, x: str, **kwargs):
     assert all_chunks[0].predict_name == "predict1"
     assert all_chunks[0].signature_field_name == "answer"
     assert all_chunks[0].is_last_chunk is False
-
-    assert all_chunks[-1].predict_name == "predict2"
-    assert all_chunks[-1].signature_field_name == "judgement"
-    assert all_chunks[-1].is_last_chunk is True
+    # The last chunk can be from either predictor because sometimes small LMs miss the `[[ ## completed ## ]]` marker,
+    # which results in an extra chunk that flushes out the buffer.
+    assert all_chunks[-2].predict_name == "predict2"
+    assert all_chunks[-2].signature_field_name == "judgement"
+    assert all_chunks[-2].is_last_chunk is True
 
 
 def test_sync_status_streaming():
@@ -599,6 +601,71 @@ async def gemini_stream_2(*args, **kwargs):
         )
 
 
+@pytest.mark.anyio
+async def test_stream_listener_missing_completion_marker_chat_adapter():
+    """Test that streaming works correctly when LLM response omits a final completion marker.
+
+    This test verifies that:
+    1. All tokens are yielded including those in the buffer
+    2. The last chunk is properly marked with is_last_chunk=True
+    3. No tokens are lost when the completion marker is missing
+    """
+
+    class MyProgram(dspy.Module):
+        def __init__(self):
+            super().__init__()
+            self.predict = dspy.Predict("question->answer")
+
+        def forward(self, question, **kwargs):
+            return self.predict(question=question, **kwargs)
+
+    async def incomplete_stream(*args, **kwargs):
+        """Stream that includes start marker but MISSING completion marker"""
+        # Start marker
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="[[ ##"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" answer"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ## ]]\n\n"))])
+
+        # Content tokens - more than 10 to ensure buffering happens
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="This"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" is"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" a"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" test"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" response"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" with"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" many"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" tokens"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" to"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ensure"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" buffering"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" works"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" correctly"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="."))])
+        # NO COMPLETION MARKER
+
+    with mock.patch("litellm.acompletion", side_effect=incomplete_stream):
+        program = dspy.streamify(
+            MyProgram(),
+            stream_listeners=[
+                dspy.streaming.StreamListener(signature_field_name="answer"),
+            ],
+        )
+        with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.ChatAdapter()):
+            output = program(question="Test question")
+            all_chunks = []
+            final_prediction = None
+            async for value in output:
+                if isinstance(value, dspy.streaming.StreamResponse):
+                    all_chunks.append(value)
+                elif isinstance(value, dspy.Prediction):
+                    final_prediction = value
+
+    full_content = "".join([chunk.chunk for chunk in all_chunks])
+    expected_content = "This is a test response with many tokens to ensure buffering works correctly."
+    assert full_content == expected_content
+    assert final_prediction.answer == expected_content
+
+
 @pytest.mark.anyio
 async def test_stream_listener_returns_correct_chunk_json_adapter_untokenized_stream():
     class MyProgram(dspy.Module):
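The assertions above check content equality only; if one also wanted to exercise point 2 of the test's docstring directly, a follow-up check along these lines would fit (illustrative, not part of this commit):

    assert all_chunks[-1].is_last_chunk is True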
@@ -915,9 +982,10 @@ async def stream(*args, **kwargs):
         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ##"))])
         yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ]]"))])
 
-
     with mock.patch("litellm.acompletion", side_effect=stream):
-        with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.ChatAdapter(native_response_types=[CustomType])):
+        with dspy.context(
+            lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.ChatAdapter(native_response_types=[CustomType])
+        ):
             output = program(question="why did a chicken cross the kitchen?")
             all_chunks = []
             async for value in output:
@@ -934,6 +1002,7 @@ async def stream(*args, **kwargs):
 async def test_streaming_with_citations():
     class AnswerWithSources(dspy.Signature):
         """Answer questions using provided documents with citations."""
+
         documents: list[Document] = dspy.InputField()
         question: str = dspy.InputField()
         answer: str = dspy.OutputField()
@@ -964,19 +1033,26 @@ async def citation_stream(*args, **kwargs):
         yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="g"))])
         yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" to "))])
         yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="the references,"))])
-        yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(
-            content="",
-            provider_specific_fields={
-                "citation": {
-                    "type": "char_location",
-                    "cited_text": "water boils at 100°C",
-                    "document_index": 0,
-                    "document_title": "Physics Facts",
-                    "start_char_index": 0,
-                    "end_char_index": 19
-                }
-            }
-        ))])
+        yield ModelResponseStream(
+            model="claude",
+            choices=[
+                StreamingChoices(
+                    delta=Delta(
+                        content="",
+                        provider_specific_fields={
+                            "citation": {
+                                "type": "char_location",
+                                "cited_text": "water boils at 100°C",
+                                "document_index": 0,
+                                "document_title": "Physics Facts",
+                                "start_char_index": 0,
+                                "end_char_index": 19,
+                            }
+                        },
+                    )
+                )
+            ],
+        )
         yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" water"))])
         yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" boils"))])
         yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" at"))])
@@ -1000,7 +1076,10 @@ async def citation_stream(*args, **kwargs):
     # Create test documents
     docs = [Document(data="Water boils at 100°C at standard pressure.", title="Physics Facts")]
 
-    with dspy.context(lm=dspy.LM("anthropic/claude-3-5-sonnet-20241022", cache=False), adapter=dspy.ChatAdapter(native_response_types=[Citations])):
+    with dspy.context(
+        lm=dspy.LM("anthropic/claude-3-5-sonnet-20241022", cache=False),
+        adapter=dspy.ChatAdapter(native_response_types=[Citations]),
+    ):
         output = program(documents=docs, question="What temperature does water boil?")
         citation_chunks = []
         answer_chunks = []

0 commit comments
