139 changes: 96 additions & 43 deletions dspy/streaming/streaming_listener.py
@@ -3,6 +3,7 @@
from queue import Queue
from typing import TYPE_CHECKING, Any

import jiter
from litellm import ModelResponseStream

from dspy.adapters.chat_adapter import ChatAdapter
@@ -49,6 +50,8 @@ def __init__(
self.cache_hit = False
self.allow_reuse = allow_reuse

self.json_adapter_state = {"field_accumulated_tokens": ""}
TomeHirata (Collaborator) commented on Nov 4, 2025:
Do we plan to introduce other keys to self.json_adapter_state? Or can we flatten the structure?

self.adapter_identifiers = {
"ChatAdapter": {
"start_identifier": f"[[ ## {self.signature_field_name} ## ]]",
@@ -62,7 +65,7 @@
"end_identifier": re.compile(r"\w*\"(,|\s*})"),
"start_indicator": '"',
"end_pattern_prefixes": ['"', '",', '" ', '"}'],
"end_pattern_contains": None,
"end_pattern_contains": "}",
},
"XMLAdapter": {
"start_identifier": f"<{self.signature_field_name}>",
@@ -126,6 +129,7 @@ def receive(self, chunk: ModelResponseStream):
self.cache_hit = False
self.field_start_queue = []
self.field_end_queue = Queue()
self.json_adapter_state["field_accumulated_tokens"] = ""
self.stream_start = False
else:
return
@@ -147,7 +151,7 @@
is_last_chunk=self.stream_end,
)

if chunk_message and start_identifier in chunk_message:
if chunk_message and start_identifier in chunk_message and not isinstance(settings.adapter, JSONAdapter):
# If the cache is hit, the chunk_message could be the full response. When it happens we can
# directly end the stream listening. In some models like gemini, each stream chunk can be multiple
# tokens, so it's possible that response only has one chunk, we also fall back to this logic.
@@ -180,10 +184,13 @@
# Keep the part after the start_identifier from the concat_message, we need to write it to the buffer.
value_start_index = concat_message.find(start_identifier) + len(start_identifier)
chunk_message = concat_message[value_start_index:].lstrip()
if isinstance(settings.adapter, JSONAdapter) and chunk_message.startswith('"'):
# For JSONAdapter, we need to remove the leading ". We cannot do this with the start_identifier
# because there could be a few splitters between ':' and '"', e.g., '"name": "value"'.
chunk_message = chunk_message[1:]

if isinstance(settings.adapter, JSONAdapter):
# For JSONAdapter, we rely on partial json parsing to detect the end of the field we are listening
# to, so we need to maintain a few extra states to help us with that.
# We add an extra "{" to the beginning of the field_accumulated_tokens, so we can detect the
# appearance of the next key.
self.json_adapter_state["field_accumulated_tokens"] += "{" + start_identifier

elif self._buffered_message_end_with_start_identifier(concat_message.strip(), start_identifier):
# If the buffered message ends with part of the start_identifier, we keep looking for the
@@ -196,30 +203,98 @@

if self.stream_start and chunk_message:
# The stream is started, we keep returning the token until we see the start of the next field.
token = None
self.field_end_queue.put(chunk_message)

token = None
concat_message = "".join(self.field_end_queue.queue).strip()
if re.search(end_identifier, concat_message):
# The next field is identified, we can end the stream and flush out all tokens in the buffer.
self.stream_end = True
token = self.flush()
token = token.rstrip() # Remove the trailing \n\n
elif not self._could_form_end_identifier(concat_message, adapter_name):

if not self._could_form_end_identifier(concat_message, adapter_name):
# Buffer cannot form end identifier, safe to flush out the tokens in the buffer.
token = self.flush()
elif self.field_end_queue.qsize() > 10:
# Buffer could form end identifier, but we've exceeded max buffer size
# Yield the oldest token to prevent unbounded buffering
# We keep the last 10 tokens in the buffer if they can potentially form the end_identifier to avoid
# sending the DSPy bolilerplate tokens to users. 10 is a heuristic number that is sufficient to capture
Copilot AI commented on Nov 4, 2025:
Corrected spelling of 'bolilerplate' to 'boilerplate'.
Suggested change:
- # sending the DSPy bolilerplate tokens to users. 10 is a heuristic number that is sufficient to capture
+ # sending the DSPy boilerplate tokens to users. 10 is a heuristic number that is sufficient to capture
# the end_identifier for all LMs.
token = self.field_end_queue.get()

if token:
if isinstance(settings.adapter, JSONAdapter):
# JSONAdapter uses partial json parsing to detect the end of the field we are listening to, instead of
# relying on the end_identifier.
return self._json_adapter_handle_stream_chunk(token, chunk_message)
else:
# Other adapters rely on the end_identifier to detect the end of the field we are listening to.
return self._default_handle_stream_chunk(token, end_identifier)
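
As a side note, the bounded-buffer heuristic above can be sketched in isolation like this (a minimal sketch with illustrative names, not the PR's exact code): tokens that might still complete an end identifier stay queued, but once more than 10 are held, the oldest is released.

```python
from queue import Queue

# Minimal sketch of the bounded lookahead buffer (illustrative names).
field_end_queue: "Queue[str]" = Queue()

def on_token(tok: str) -> str | None:
    field_end_queue.put(tok)
    if field_end_queue.qsize() > 10:
        # Yield the oldest token to prevent unbounded buffering.
        return field_end_queue.get()
    return None
```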

def _json_adapter_handle_stream_chunk(self, token: str, chunk_message: str) -> StreamResponse | None:
Copilot AI commented on Nov 4, 2025:
Mixing implicit and explicit returns may indicate an error, as implicit returns always return None.

self.json_adapter_state["field_accumulated_tokens"] += chunk_message
Collaborator commented:
Why do we accumulate chunk_message instead of token?

if self.json_adapter_state["field_accumulated_tokens"].rstrip().endswith("}"):
# When the accumulated tokens ends with a curly b racket, that means the streaming for the predict we are
Copilot AI commented on Nov 4, 2025:
Corrected spelling of 'b racket' (with extra space) to 'bracket'.
Suggested change:
- # When the accumulated tokens ends with a curly b racket, that means the streaming for the predict we are
+ # When the accumulated tokens ends with a curly bracket, that means the streaming for the predict we are
Collaborator commented:
Suggested change:
- # When the accumulated tokens ends with a curly b racket, that means the streaming for the predict we are
+ # When the accumulated tokens end with a curly bracket, that means the streaming for the prediction we are
# listening to is probably finished, we need to run a check and decide whether to end the stream.
try:
# If the parse doesn't raise an error, the accumulated tokens form a valid json object. Because
# we add an extra "{" to the beginning of field_accumulated_tokens, we know the streaming is
# finished.
jiter.from_json(self.json_adapter_state["field_accumulated_tokens"].encode("utf-8"))
self.stream_end = True
last_token = self.flush()
right_curly_bracket_index = last_token.rfind("}")
token = (
token + last_token[:right_curly_bracket_index] if token else last_token[:right_curly_bracket_index]
)
return StreamResponse(
self.predict_name, self.signature_field_name, token, is_last_chunk=self.stream_end
)

Collaborator commented:
So overall we will return a raw string chunk, so the deserialization needs to happen on the caller side?

except Exception:
Collaborator commented:
Can't we limit this to be ValueError?

pass

try:
parsed = jiter.from_json(
self.json_adapter_state["field_accumulated_tokens"].encode("utf-8"),
partial_mode="trailing-strings",
)

Collaborator commented:
This is interesting, can't we just count the number of { and }?

Collaborator (Author) replied:
Discussed offline, please see the new implementation for a more robust solution.

if len(parsed) > 1:
# If partial json parsing finds a second key, that means the streaming for the field we are listening to
# is finished.
self.stream_end = True
last_token = self.flush()

keys = list(parsed.keys())
TomeHirata (Collaborator) commented on Nov 4, 2025:
Is parsed.keys ordered based on the key order in the json string?

next_field_name = None
for key in keys:
if key != self.signature_field_name:
next_field_name = key
break

last_token_index = last_token.find(next_field_name)
token = token + last_token[:last_token_index] if token else last_token[:last_token_index]
except ValueError:
Copilot AI commented on Nov 4, 2025:
'except' clause does nothing but pass and there is no explanatory comment.

pass

if token:
return StreamResponse(
self.predict_name,
self.signature_field_name,
token,
is_last_chunk=self.stream_end,
)
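
For intuition, here is a small standalone sketch of the two jiter checks used above (field names are hypothetical): a clean full parse means the accumulated JSON object has closed, while a partial parse that surfaces a second key means the listened-to field has finished streaming.

```python
import jiter

# Case 1: the accumulated tokens form a complete JSON object -> stream finished.
done = '{"answer": "it is 42"}'
jiter.from_json(done.encode("utf-8"))  # parses without raising

# Case 2: a second key shows up mid-stream -> the "answer" field has ended.
partial = '{"answer": "it is 42", "confidence": "hi'
parsed = jiter.from_json(partial.encode("utf-8"), partial_mode="trailing-strings")
assert list(parsed.keys()) == ["answer", "confidence"]
```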

def _default_handle_stream_chunk(self, token: str, end_identifier: str) -> StreamResponse | None:
Copilot AI commented on Nov 4, 2025:
Mixing implicit and explicit returns may indicate an error, as implicit returns always return None.

concat_message = "".join(self.field_end_queue.queue).strip()

if re.search(end_identifier, concat_message):
# The next field is identified, we can end the stream and flush out all tokens in the buffer.
self.stream_end = True
last_token = self.flush()
token = token + last_token if token else last_token
token = token.rstrip() # Remove the trailing \n\n

if token:
return StreamResponse(
self.predict_name,
self.signature_field_name,
token,
is_last_chunk=self.stream_end,
)
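
For comparison, the default (non-JSON) end-of-field check boils down to a regex search over the buffered tokens; a minimal sketch, assuming an XMLAdapter-style end identifier derived from the field name (as flush below does):

```python
import re

signature_field_name = "answer"  # illustrative field name
end_identifier = re.escape(f"</{signature_field_name}>")

# Tokens buffered so far, joined exactly like "".join(field_end_queue.queue):
concat_message = "The answer is 42</answer>".strip()
if re.search(end_identifier, concat_message):
    # Next-field boundary found: end the stream and flush buffered tokens.
    print("stream_end")
```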

def flush(self) -> str:
"""Flush all tokens in the field end queue.
@@ -231,12 +306,7 @@ def flush(self) -> str:
last_tokens = "".join(self.field_end_queue.queue)
self.field_end_queue = Queue()
if isinstance(settings.adapter, JSONAdapter):
match = re.search(r'",|"\s*}', last_tokens)
if match:
boundary_index = match.start()
else:
boundary_index = len(last_tokens)
return last_tokens[:boundary_index]
return last_tokens
elif isinstance(settings.adapter, XMLAdapter):
boundary_index = last_tokens.find(f"</{self.signature_field_name}>")
if boundary_index == -1:
@@ -314,13 +384,6 @@ def find_predictor_for_stream_listeners(
f"Signature field {field_name} is not unique in the program, cannot automatically determine which "
"predictor to use for streaming. Please specify the predictor to listen to."
)

if not _is_streamable(field_info.annotation):
raise ValueError(
f"Stream listener can only be applied to string or subclass of `dspy.Type` that has `is_streamable() == True`, "
f"but your field {field_name} is of type {field_info.annotation}."
)

field_name_to_named_predictor[field_name] = (name, predictor)

predict_id_to_listener = defaultdict(list)
@@ -337,13 +400,3 @@
listener.predict_name, listener.predict = field_name_to_named_predictor[listener.signature_field_name]
predict_id_to_listener[id(listener.predict)].append(listener)
return predict_id_to_listener
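
For context, a usage-level sketch of where these listeners plug in, based on dspy's public streaming API (exact signatures may vary by version):

```python
import dspy
from dspy.streaming import StreamListener

# streamify wires each listener to the predictor whose signature contains the
# listened-to field, using find_predictor_for_stream_listeners above.
predict = dspy.Predict("question -> answer")
streaming_predict = dspy.streamify(
    predict,
    stream_listeners=[StreamListener(signature_field_name="answer")],
)
```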


def _is_streamable(field_type: type | None) -> bool:
if field_type is None:
return False
if field_type is str:
return True
if issubclass(field_type, Type):
return field_type.is_streamable()
Collaborator commented:
Shall we delete the is_streamable method of Type?
return False