
Commit 1bf4859

fix: ensure Bedrock Converse observation nests properly in chat trace
1 parent 55ae73e commit 1bf4859


src/api/models/bedrock.py

Lines changed: 124 additions & 138 deletions
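
For orientation, here is a minimal sketch (not part of this diff) of the structure the commit assumes: the chat_completions endpoint carries the @observe decorator that creates the parent trace, and after this change chat()/chat_stream() report into that trace via langfuse_context.update_current_trace() instead of _invoke_bedrock opening its own "Bedrock Converse" generation observation. The endpoint name and decorator come from the docstrings below; the stub class, function signature, and return value are purely illustrative.

    # Minimal sketch, assuming the Langfuse v2 decorator API (langfuse.decorators).
    import asyncio
    from langfuse.decorators import observe, langfuse_context


    class FakeBedrockModel:  # stand-in for the real model class in bedrock.py
        async def chat(self, chat_request):
            # The real chat() now calls langfuse_context.update_current_trace(...)
            # so its usage/metadata attach to the parent trace created below.
            langfuse_context.update_current_trace(
                metadata={"model": chat_request["model"], "stream": False}
            )
            return {"choices": [{"message": {"role": "assistant", "content": "ok"}}]}


    @observe(name="chat_completions")  # creates the parent trace this commit relies on
    async def chat_completions(chat_request):
        return await FakeBedrockModel().chat(chat_request)


    if __name__ == "__main__":
        asyncio.run(chat_completions({"model": "example-model"}))
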
@@ -248,7 +248,6 @@ def validate(self, chat_request: ChatRequest):
                 detail=error,
             )
 
-    @observe(as_type="generation", name="Bedrock Converse")
     async def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
         """Common logic for invoke bedrock models"""
         if DEBUG:
@@ -259,29 +258,6 @@ async def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
         if DEBUG:
             logger.info("Bedrock request: " + json.dumps(str(args)))
 
-        # Extract model metadata for Langfuse
-        args_clone = args.copy()
-        messages = args_clone.get('messages', [])
-        model_id = args_clone.get('modelId', 'unknown')
-        model_parameters = {
-            **args_clone.get('inferenceConfig', {}),
-            **args_clone.get('additionalModelRequestFields', {})
-        }
-
-        # Update Langfuse generation with input metadata
-        langfuse_context.update_current_observation(
-            input=messages,
-            model=model_id,
-            model_parameters=model_parameters,
-            metadata={
-                'system': args_clone.get('system', []),
-                'toolConfig': args_clone.get('toolConfig', {}),
-                'stream': stream
-            }
-        )
-        if DEBUG:
-            logger.info(f"Langfuse: Updated observation with input - model={model_id}, stream={stream}, messages_count={len(messages)}")
-
         try:
             if stream:
                 # Run the blocking boto3 call in a thread pool
@@ -291,93 +267,99 @@ async def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
             else:
                 # Run the blocking boto3 call in a thread pool
                 response = await run_in_threadpool(bedrock_runtime.converse, **args)
-
-            # For non-streaming, extract response metadata immediately
-            if response and not stream:
-                output_message = response.get("output", {}).get("message", {})
-                usage = response.get("usage", {})
-
-                # Build metadata
-                metadata = {
-                    "stopReason": response.get("stopReason"),
-                    "ResponseMetadata": response.get("ResponseMetadata", {})
-                }
-
-                # Check for reasoning content in response
-                has_reasoning = False
-                reasoning_text = ""
-                if output_message and "content" in output_message:
-                    for content_block in output_message.get("content", []):
-                        if "reasoningContent" in content_block:
-                            has_reasoning = True
-                            reasoning_text = content_block.get("reasoningContent", {}).get("reasoningText", {}).get("text", "")
-                            break
-
-                if has_reasoning and reasoning_text:
-                    metadata["has_extended_thinking"] = True
-                    metadata["reasoning_content"] = reasoning_text
-                    metadata["reasoning_tokens_estimate"] = len(reasoning_text) // 4
-
-                langfuse_context.update_current_observation(
-                    output=output_message,
-                    usage={
-                        "input": usage.get("inputTokens", 0),
-                        "output": usage.get("outputTokens", 0),
-                        "total": usage.get("totalTokens", 0)
-                    },
-                    metadata=metadata
-                )
-                if DEBUG:
-                    logger.info(f"Langfuse: Updated observation with output - "
-                                f"input_tokens={usage.get('inputTokens', 0)}, "
-                                f"output_tokens={usage.get('outputTokens', 0)}, "
-                                f"has_reasoning={has_reasoning}, "
-                                f"stop_reason={response.get('stopReason')}")
         except bedrock_runtime.exceptions.ValidationException as e:
             error_message = f"Bedrock validation error for model {chat_request.model}: {str(e)}"
             logger.error(error_message)
-            langfuse_context.update_current_observation(level="ERROR", status_message=error_message)
-            if DEBUG:
-                logger.info("Langfuse: Updated observation with ValidationException error")
             raise HTTPException(status_code=400, detail=str(e))
         except bedrock_runtime.exceptions.ThrottlingException as e:
             error_message = f"Bedrock throttling for model {chat_request.model}: {str(e)}"
             logger.warning(error_message)
-            langfuse_context.update_current_observation(level="WARNING", status_message=error_message)
-            if DEBUG:
-                logger.info("Langfuse: Updated observation with ThrottlingException warning")
             raise HTTPException(status_code=429, detail=str(e))
         except Exception as e:
             error_message = f"Bedrock invocation failed for model {chat_request.model}: {str(e)}"
             logger.error(error_message)
-            langfuse_context.update_current_observation(level="ERROR", status_message=error_message)
-            if DEBUG:
-                logger.info("Langfuse: Updated observation with generic Exception error")
             raise HTTPException(status_code=500, detail=str(e))
         return response
 
     async def chat(self, chat_request: ChatRequest) -> ChatResponse:
-        """Default implementation for Chat API."""
-
+        """Default implementation for Chat API.
+
+        Note: Works within the parent trace context created by @observe
+        decorator on chat_completions endpoint. Updates that trace context
+        with the response data.
+        """
         message_id = self.generate_message_id()
-        response = await self._invoke_bedrock(chat_request)
-
-        output_message = response["output"]["message"]
-        input_tokens = response["usage"]["inputTokens"]
-        output_tokens = response["usage"]["outputTokens"]
-        finish_reason = response["stopReason"]
-
-        chat_response = self._create_response(
-            model=chat_request.model,
-            message_id=message_id,
-            content=output_message["content"],
-            finish_reason=finish_reason,
-            input_tokens=input_tokens,
-            output_tokens=output_tokens,
-        )
-        if DEBUG:
-            logger.info("Proxy response :" + chat_response.model_dump_json())
-        return chat_response
+
+        try:
+            if DEBUG:
+                logger.info(f"Langfuse: Starting non-streaming request for model={chat_request.model}")
+
+            response = await self._invoke_bedrock(chat_request)
+
+            output_message = response["output"]["message"]
+            input_tokens = response["usage"]["inputTokens"]
+            output_tokens = response["usage"]["outputTokens"]
+            finish_reason = response["stopReason"]
+
+            # Build metadata including usage info
+            trace_metadata = {
+                "model": chat_request.model,
+                "stream": False,
+                "stopReason": finish_reason,
+                "usage": {
+                    "prompt_tokens": input_tokens,
+                    "completion_tokens": output_tokens,
+                    "total_tokens": input_tokens + output_tokens
+                },
+                "ResponseMetadata": response.get("ResponseMetadata", {})
+            }
+
+            # Check for reasoning content in response
+            has_reasoning = False
+            reasoning_text = ""
+            if output_message and "content" in output_message:
+                for content_block in output_message.get("content", []):
+                    if "reasoningContent" in content_block:
+                        has_reasoning = True
+                        reasoning_text = content_block.get("reasoningContent", {}).get("reasoningText", {}).get("text", "")
+                        break
+
+            if has_reasoning and reasoning_text:
+                trace_metadata["has_extended_thinking"] = True
+                trace_metadata["reasoning_content"] = reasoning_text
+                trace_metadata["reasoning_tokens_estimate"] = len(reasoning_text) // 4
+
+            # Update trace with metadata
+            langfuse_context.update_current_trace(
+                metadata=trace_metadata
+            )
+
+            if DEBUG:
+                logger.info(f"Langfuse: Non-streaming response - "
+                            f"input_tokens={input_tokens}, "
+                            f"output_tokens={output_tokens}, "
+                            f"has_reasoning={has_reasoning}, "
+                            f"stop_reason={finish_reason}")
+
+            chat_response = self._create_response(
+                model=chat_request.model,
+                message_id=message_id,
+                content=output_message["content"],
+                finish_reason=finish_reason,
+                input_tokens=input_tokens,
+                output_tokens=output_tokens,
+            )
+            if DEBUG:
+                logger.info("Proxy response :" + chat_response.model_dump_json())
+            return chat_response
+        except HTTPException:
+            # Re-raise HTTPException as-is
+            raise
+        except Exception as e:
+            logger.error("Chat error for model %s: %s", chat_request.model, str(e))
+            if DEBUG:
+                logger.info(f"Langfuse: Error in non-streaming - error={str(e)[:100]}")
+            raise
 
     async def _async_iterate(self, stream):
         """Helper method to convert sync iterator to async iterator"""
@@ -386,10 +368,21 @@ async def _async_iterate(self, stream):
             yield chunk
 
     async def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
-        """Default implementation for Chat Stream API"""
+        """Default implementation for Chat Stream API
+
+        Note: For streaming, we work within the parent trace context created by @observe
+        decorator on chat_completions endpoint. We update that trace context with
+        streaming data as it arrives.
+        """
         try:
             if DEBUG:
                 logger.info(f"Langfuse: Starting streaming request for model={chat_request.model}")
+
+            # Parse request for metadata to log in parent trace
+            args = self._parse_request(chat_request)
+            messages = args.get('messages', [])
+            model_id = args.get('modelId', 'unknown')
+
             response = await self._invoke_bedrock(chat_request, stream=True)
             message_id = self.generate_message_id()
             stream = response.get("stream")
@@ -403,8 +396,8 @@ async def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
             has_reasoning = False
 
             async for chunk in self._async_iterate(stream):
-                args = {"model_id": chat_request.model, "message_id": message_id, "chunk": chunk}
-                stream_response = self._create_response_stream(**args)
+                args_chunk = {"model_id": chat_request.model, "message_id": message_id, "chunk": chunk}
+                stream_response = self._create_response_stream(**args_chunk)
                 if not stream_response:
                     continue
 
@@ -438,49 +431,46 @@ async def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
                 # All other chunks will also include a usage field, but with a null value.
                 yield self.stream_response_to_bytes(stream_response)
 
-            # Update Langfuse with final streaming metadata (both observation and trace)
+            # Update Langfuse trace with final streaming output
+            # This updates the parent trace from chat_completions
             if final_usage or accumulated_output:
-                update_params = {}
-                if accumulated_output:
-                    final_output = "".join(accumulated_output)
-                    update_params["output"] = final_output
-                if final_usage:
-                    update_params["usage"] = {
-                        "input": final_usage.prompt_tokens,
-                        "output": final_usage.completion_tokens,
-                        "total": final_usage.total_tokens
-                    }
-                # Build metadata
-                metadata = {}
-                if finish_reason:
-                    metadata["finish_reason"] = finish_reason
-                if has_reasoning and accumulated_reasoning:
-                    reasoning_text = "".join(accumulated_reasoning)
-                    metadata["has_extended_thinking"] = True
-                    metadata["reasoning_content"] = reasoning_text
-                    # Estimate reasoning tokens (rough approximation: ~4 chars per token)
-                    metadata["reasoning_tokens_estimate"] = len(reasoning_text) // 4
-                if metadata:
-                    update_params["metadata"] = metadata
-
-                # Update the child observation (Bedrock Converse)
-                langfuse_context.update_current_observation(**update_params)
-
-                # Also update the parent trace (chat_completion) with final output
+                final_output = "".join(accumulated_output) if accumulated_output else None
                 trace_output = {
                     "message": {
                         "role": "assistant",
-                        "content": final_output if accumulated_output else None,
+                        "content": final_output,
                     },
                     "finish_reason": finish_reason,
                 }
-                langfuse_context.update_current_trace(output=trace_output)
+
+                # Build metadata including usage info
+                trace_metadata = {
+                    "model": model_id,
+                    "stream": True,
+                }
+                if finish_reason:
+                    trace_metadata["finish_reason"] = finish_reason
+                if final_usage:
+                    trace_metadata["usage"] = {
+                        "prompt_tokens": final_usage.prompt_tokens,
+                        "completion_tokens": final_usage.completion_tokens,
+                        "total_tokens": final_usage.total_tokens
+                    }
+                if has_reasoning and accumulated_reasoning:
+                    reasoning_text = "".join(accumulated_reasoning)
+                    trace_metadata["has_extended_thinking"] = True
+                    trace_metadata["reasoning_tokens_estimate"] = len(reasoning_text) // 4
+
+                langfuse_context.update_current_trace(
+                    output=trace_output,
+                    metadata=trace_metadata
+                )
 
             if DEBUG:
                 output_length = len(accumulated_output)
-                logger.info(f"Langfuse: Updated observation and trace with streaming output - "
+                logger.info(f"Langfuse: Updated trace with streaming output - "
                             f"chunks_count={output_length}, "
-                            f"output_chars={len(final_output) if accumulated_output else 0}, "
+                            f"output_chars={len(final_output) if final_output else 0}, "
                             f"input_tokens={final_usage.prompt_tokens if final_usage else 'N/A'}, "
                             f"output_tokens={final_usage.completion_tokens if final_usage else 'N/A'}, "
                             f"has_reasoning={has_reasoning}, "
@@ -490,21 +480,17 @@ async def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
             yield self.stream_response_to_bytes()
             self.think_emitted = False # Cleanup
         except HTTPException:
-            # HTTPException already has Langfuse updated in _invoke_bedrock, re-raise it
+            # Re-raise HTTPException as-is
             raise
         except Exception as e:
             logger.error("Stream error for model %s: %s", chat_request.model, str(e))
-            # Update Langfuse with error (both observation and trace)
-            langfuse_context.update_current_observation(
-                level="ERROR",
-                status_message=f"Stream error: {str(e)}"
-            )
+            # Update Langfuse with error
             langfuse_context.update_current_trace(
                 output={"error": str(e)},
-                metadata={"error": True}
+                metadata={"error": True, "error_type": type(e).__name__}
             )
             if DEBUG:
-                logger.info(f"Langfuse: Updated observation with streaming error - error={str(e)[:100]}")
+                logger.info(f"Langfuse: Updated trace with streaming error - error={str(e)[:100]}")
             error_event = Error(error=ErrorMessage(message=str(e)))
             yield self.stream_response_to_bytes(error_event)

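After this change, both the streaming and non-streaming paths converge on the same kind of parent-trace update. The self-contained sketch below shows that payload shape with placeholder values; the field names mirror the diff, and the call only takes effect when executed inside an @observe-decorated function.

    # Sketch of the trace payload the new chat()/chat_stream() code builds;
    # values are placeholders, not taken from a real Bedrock response.
    from langfuse.decorators import langfuse_context

    final_output = "Hello!"          # joined stream deltas, or the non-streaming message text
    reasoning_text = "thinking..."   # populated only when a reasoningContent block was returned

    trace_metadata = {
        "model": "example-model-id",  # placeholder model id
        "stream": True,               # False on the chat() path
        "finish_reason": "end_turn",
        "usage": {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46},
    }
    if reasoning_text:
        trace_metadata["has_extended_thinking"] = True
        # rough approximation used in the diff: ~4 characters per token
        trace_metadata["reasoning_tokens_estimate"] = len(reasoning_text) // 4

    langfuse_context.update_current_trace(
        output={"message": {"role": "assistant", "content": final_output},
                "finish_reason": "end_turn"},
        metadata=trace_metadata,
    )
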