@@ -359,6 +359,9 @@ def _create_llm_span(
359359 _set_span_attribute (span , SpanAttributes .LLM_SYSTEM , vendor )
360360 _set_span_attribute (span , SpanAttributes .LLM_REQUEST_TYPE , request_type .value )
361361
362+ span_kind = self ._determine_llm_span_kind (serialized )
363+ _set_span_attribute (span , SpanAttributes .TRACELOOP_SPAN_KIND , span_kind .value )
364+
362365 # we already have an LLM span by this point,
363366 # so skip any downstream instrumentation from here
364367 try :
@@ -375,6 +378,72 @@ def _create_llm_span(
375378
376379 return span
377380
def _determine_llm_span_kind(self, serialized: Optional[dict[str, Any]]) -> TraceloopSpanKindValues:
    """Choose the span kind recorded for an LLM call.

    Args:
        serialized: Serialized description of the model, or a falsy value
            when none was supplied.

    Returns:
        ``TraceloopSpanKindValues.EMBEDDING`` when the lower-cased class name
        obtained from ``_extract_class_name_from_serialized`` contains
        ``"embedding"`` or ``"embed"``; ``TraceloopSpanKindValues.GENERATION``
        otherwise, including when ``serialized`` is empty or ``None``.
    """
    if serialized:
        lowered_class = _extract_class_name_from_serialized(serialized).lower()
        for marker in ('embedding', 'embed'):
            if marker in lowered_class:
                return TraceloopSpanKindValues.EMBEDDING

    # Anything not recognised as an embedding model is treated as generation.
    return TraceloopSpanKindValues.GENERATION
394+
def _determine_chain_span_kind(
    self,
    serialized: dict[str, Any],
    name: str,
    tags: Optional[list[str]] = None,
) -> TraceloopSpanKindValues:
    """Pick a span kind for a chain run using substring heuristics.

    The checks run in a fixed order and the first match wins: agent, tool,
    retriever, embedding, reranker. Every comparison is a case-insensitive
    substring test against the parts of ``serialized["id"]``, the class name
    returned by ``_extract_class_name_from_serialized``, ``name``, or
    ``tags``.

    Args:
        serialized: Serialized description of the runnable. When it has an
            ``"id"`` entry, each part of that entry is scanned for "agent".
        name: Name resolved for the run.
        tags: Optional run tags; only consulted for tool detection.

    Returns:
        The first matching ``TraceloopSpanKindValues`` member, or
        ``TraceloopSpanKindValues.TASK`` when no heuristic matches.
    """

    def contains_any(text, keywords):
        # True when at least one keyword occurs somewhere in ``text``.
        return any(keyword in text for keyword in keywords)

    # Agent detection: any part of the serialized id path, then the run name.
    if serialized and "id" in serialized:
        for path_part in serialized["id"]:
            if "agent" in path_part.lower():
                return TraceloopSpanKindValues.AGENT

    lowered_name = name.lower()
    if "agent" in lowered_name:
        return TraceloopSpanKindValues.AGENT

    lowered_class = _extract_class_name_from_serialized(serialized).lower()

    # Tool detection for RunnableLambda and custom tool chains. A name that
    # mentions "function" also counts, unless it mentions "parser" as well
    # (this excludes operations such as parsers).
    if (
        contains_any(lowered_class, ['tool'])
        or contains_any(lowered_name, ['tool'])
        or ('function' in lowered_name and 'parser' not in lowered_name)
        or (tags and any('tool' in tag.lower() for tag in tags))
    ):
        return TraceloopSpanKindValues.TOOL

    # Retriever detection for RunnableLambda and custom chains; the class
    # name and the run name use slightly different keyword lists.
    if contains_any(lowered_class, ['retriever', 'retrieve', 'vectorstore']) or contains_any(
        lowered_name, ['retriever', 'retrieve', 'search']
    ):
        return TraceloopSpanKindValues.RETRIEVER

    # Embedding detection for RunnableLambda and custom chains.
    embedding_keywords = ['embedding', 'embed']
    if contains_any(lowered_class, embedding_keywords) or contains_any(lowered_name, embedding_keywords):
        return TraceloopSpanKindValues.EMBEDDING

    # Reranker detection.
    reranker_keywords = ['rerank', 'reorder']
    if contains_any(lowered_class, reranker_keywords) or contains_any(lowered_name, reranker_keywords):
        return TraceloopSpanKindValues.RERANKER

    return TraceloopSpanKindValues.TASK
446+
378447 @dont_throw
379448 def on_chain_start (
380449 self ,
@@ -395,12 +464,18 @@ def on_chain_start(
395464 entity_path = ""
396465
397466 name = self ._get_name_from_callback (serialized , ** kwargs )
398- kind = (
467+
468+ base_kind = (
399469 TraceloopSpanKindValues .WORKFLOW
400470 if parent_run_id is None or parent_run_id not in self .spans
401471 else TraceloopSpanKindValues .TASK
402472 )
403473
474+ if base_kind == TraceloopSpanKindValues .TASK :
475+ kind = self ._determine_chain_span_kind (serialized , name , tags )
476+ else :
477+ kind = base_kind
478+
404479 if kind == TraceloopSpanKindValues .WORKFLOW :
405480 workflow_name = name
406481 else :
@@ -710,6 +785,73 @@ def on_tool_end(
710785 )
711786 self ._end_span (span , run_id )
712787
@dont_throw
def on_retriever_start(
    self,
    serialized: dict[str, Any],
    query: str,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    tags: Optional[list[str]] = None,
    metadata: Optional[dict[str, Any]] = None,
    **kwargs: Any,
) -> None:
    """Run when retriever starts running.

    Creates a span of kind ``TraceloopSpanKindValues.RETRIEVER`` for the run.
    When events are not being emitted and prompts may be sent, the query,
    tags, metadata and remaining keyword arguments are JSON-encoded into the
    span's ``TRACELOOP_ENTITY_INPUT`` attribute.

    Args:
        serialized: Serialized description of the retriever, used to resolve
            the span name.
        query: The query passed to the retriever.
        run_id: Identifier of this retriever run.
        parent_run_id: Identifier of the parent run, if any; used to look up
            the workflow name and entity path.
        tags: Optional tags attached to the run.
        metadata: Optional metadata attached to the run.
        **kwargs: Extra callback arguments; forwarded to name resolution and
            included in the recorded input.
    """
    if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
        return

    # Forward the callback kwargs unpacked, the same way on_chain_start does.
    # Writing ``kwargs=kwargs`` would instead pass one keyword argument that
    # is literally named "kwargs".
    name = self._get_name_from_callback(serialized, **kwargs)
    workflow_name = self.get_workflow_name(parent_run_id)
    entity_path = self.get_entity_path(parent_run_id)

    span = self._create_task_span(
        run_id,
        parent_run_id,
        name,
        TraceloopSpanKindValues.RETRIEVER,
        workflow_name,
        name,
        entity_path,
    )
    if not should_emit_events() and should_send_prompts():
        span.set_attribute(
            SpanAttributes.TRACELOOP_ENTITY_INPUT,
            json.dumps(
                {
                    "query": query,
                    "tags": tags,
                    "metadata": metadata,
                    "kwargs": kwargs,
                },
                cls=CallbackFilteredJSONEncoder,
            ),
        )
830+
@dont_throw
def on_retriever_end(
    self,
    documents: Any,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> None:
    """Run when retriever ends running.

    Looks up the span registered for ``run_id``, optionally records the
    retrieved documents on it, and ends the span.

    Args:
        documents: The retriever's result; recorded as ``str(documents)``
            truncated to its first 1000 characters.
        run_id: Identifier of the retriever run whose span is being closed.
        parent_run_id: Identifier of the parent run, if any (unused here).
        **kwargs: Extra callback arguments; included in the recorded output.
    """
    if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
        return

    span = self._get_span(run_id)

    if not should_emit_events() and should_send_prompts():
        # Limit output size
        truncated_documents = str(documents)[:1000]
        output_payload = json.dumps(
            {"documents": truncated_documents, "kwargs": kwargs},
            cls=CallbackFilteredJSONEncoder,
        )
        span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_OUTPUT, output_payload)

    self._end_span(span, run_id)
854+
713855 def get_parent_span (self , parent_run_id : Optional [str ] = None ):
714856 if parent_run_id is None :
715857 return None
0 commit comments