@@ -359,6 +359,9 @@ def _create_llm_span(
359359 _set_span_attribute (span , SpanAttributes .LLM_SYSTEM , vendor )
360360 _set_span_attribute (span , SpanAttributes .LLM_REQUEST_TYPE , request_type .value )
361361
362+ span_kind = self ._determine_llm_span_kind (serialized )
363+ _set_span_attribute (span , SpanAttributes .TRACELOOP_SPAN_KIND , span_kind .value )
364+
362365 # we already have an LLM span by this point,
363366 # so skip any downstream instrumentation from here
364367 try :
@@ -375,6 +378,69 @@ def _create_llm_span(
375378
376379 return span
377380
def _determine_llm_span_kind(self, serialized: Optional[dict[str, Any]]) -> TraceloopSpanKindValues:
    """Choose the span kind for an LLM operation from its serialized class name.

    Returns EMBEDDING when the lower-cased class name extracted from
    ``serialized`` contains an embedding marker; otherwise (including when
    ``serialized`` is empty or None) returns GENERATION.
    """
    embedding_markers = ('embedding', 'embed')

    if serialized:
        lowered_class = _extract_class_name_from_serialized(serialized).lower()
        for marker in embedding_markers:
            if marker in lowered_class:
                return TraceloopSpanKindValues.EMBEDDING

    # Anything not recognised as an embedding model is treated as generation.
    return TraceloopSpanKindValues.GENERATION
394+
def _determine_chain_span_kind(
    self,
    serialized: dict[str, Any],
    name: str,
    tags: Optional[list[str]] = None
) -> TraceloopSpanKindValues:
    """Classify a chain run by substring heuristics on its class path, name and tags.

    Checks are applied in a fixed order and the first match wins:
    AGENT (any part of ``serialized["id"]``, then ``name``), TOOL (class
    name, ``name``, then ``tags``), RETRIEVER, EMBEDDING, RERANKER (class
    name, then ``name`` for each). All comparisons are case-insensitive.
    Returns TASK when nothing matches.
    """
    def mentions(text, keywords):
        # True when at least one keyword occurs as a substring of text.
        return any(keyword in text for keyword in keywords)

    if serialized and "id" in serialized:
        for part in serialized["id"]:
            if "agent" in part.lower():
                return TraceloopSpanKindValues.AGENT

    lowered_name = name.lower()
    if "agent" in lowered_name:
        return TraceloopSpanKindValues.AGENT

    lowered_class = _extract_class_name_from_serialized(serialized).lower()

    # Tool detection: class name first, then run name, then tags.
    if mentions(lowered_class, ['tool']) or mentions(lowered_name, ['tool', 'function']):
        return TraceloopSpanKindValues.TOOL
    if tags and any('tool' in tag.lower() for tag in tags):
        return TraceloopSpanKindValues.TOOL

    # Retriever detection.
    if mentions(lowered_class, ['retriever', 'retrieve', 'vectorstore']):
        return TraceloopSpanKindValues.RETRIEVER
    if mentions(lowered_name, ['retriever', 'retrieve', 'search']):
        return TraceloopSpanKindValues.RETRIEVER

    # Embedding detection.
    if mentions(lowered_class, ['embedding', 'embed']) or mentions(lowered_name, ['embedding', 'embed']):
        return TraceloopSpanKindValues.EMBEDDING

    # Reranker detection.
    if mentions(lowered_class, ['rerank', 'reorder']) or mentions(lowered_name, ['rerank', 'reorder']):
        return TraceloopSpanKindValues.RERANKER

    return TraceloopSpanKindValues.TASK
443+
378444 @dont_throw
379445 def on_chain_start (
380446 self ,
@@ -395,12 +461,18 @@ def on_chain_start(
395461 entity_path = ""
396462
397463 name = self ._get_name_from_callback (serialized , ** kwargs )
398- kind = (
464+
465+ base_kind = (
399466 TraceloopSpanKindValues .WORKFLOW
400467 if parent_run_id is None or parent_run_id not in self .spans
401468 else TraceloopSpanKindValues .TASK
402469 )
403470
471+ if base_kind == TraceloopSpanKindValues .TASK :
472+ kind = self ._determine_chain_span_kind (serialized , name , tags )
473+ else :
474+ kind = base_kind
475+
404476 if kind == TraceloopSpanKindValues .WORKFLOW :
405477 workflow_name = name
406478 else :
@@ -710,6 +782,73 @@ def on_tool_end(
710782 )
711783 self ._end_span (span , run_id )
712784
@dont_throw
def on_retriever_start(
    self,
    serialized: dict[str, Any],
    query: str,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    tags: Optional[list[str]] = None,
    metadata: Optional[dict[str, Any]] = None,
    **kwargs: Any,
) -> None:
    """Run when retriever starts running.

    Opens a span of kind RETRIEVER for ``run_id``. When events are not
    being emitted and prompt capture is enabled, the query, tags, metadata
    and remaining callback kwargs are JSON-encoded onto the span as its
    entity input. Does nothing when instrumentation is suppressed in the
    current context.
    """
    if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
        return

    # Unpack the callback kwargs, matching how on_chain_start calls this
    # helper. Passing ``kwargs=kwargs`` would wrap them under a single
    # "kwargs" key, so top-level entries would not reach the helper as
    # keyword arguments.
    name = self._get_name_from_callback(serialized, **kwargs)
    workflow_name = self.get_workflow_name(parent_run_id)
    entity_path = self.get_entity_path(parent_run_id)

    span = self._create_task_span(
        run_id,
        parent_run_id,
        name,
        TraceloopSpanKindValues.RETRIEVER,
        workflow_name,
        name,
        entity_path,
    )
    if not should_emit_events() and should_send_prompts():
        span.set_attribute(
            SpanAttributes.TRACELOOP_ENTITY_INPUT,
            json.dumps(
                {
                    "query": query,
                    "tags": tags,
                    "metadata": metadata,
                    "kwargs": kwargs,
                },
                cls=CallbackFilteredJSONEncoder,
            ),
        )
827+
@dont_throw
def on_retriever_end(
    self,
    documents: Any,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> None:
    """Run when retriever ends running.

    Looks up the span for ``run_id``, optionally records a truncated string
    form of ``documents`` plus the callback kwargs as the span's entity
    output, and ends the span. Does nothing when instrumentation is
    suppressed in the current context.
    """
    suppressed = context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY)
    if suppressed:
        return

    span = self._get_span(run_id)

    if not should_emit_events() and should_send_prompts():
        write_attribute = span.set_attribute
        output_key = SpanAttributes.TRACELOOP_ENTITY_OUTPUT
        # Only the first 1000 characters of the stringified documents are
        # kept, to limit the size of the recorded output.
        preview = str(documents)[:1000]
        encoded = json.dumps(
            {"documents": preview, "kwargs": kwargs},
            cls=CallbackFilteredJSONEncoder,
        )
        write_attribute(output_key, encoded)

    self._end_span(span, run_id)
851+
713852 def get_parent_span (self , parent_run_id : Optional [str ] = None ):
714853 if parent_run_id is None :
715854 return None
0 commit comments