diff --git a/src/instana/instrumentation/kafka/confluent_kafka_python.py b/src/instana/instrumentation/kafka/confluent_kafka_python.py index 04b1164c..e5d991d2 100644 --- a/src/instana/instrumentation/kafka/confluent_kafka_python.py +++ b/src/instana/instrumentation/kafka/confluent_kafka_python.py @@ -1,19 +1,27 @@ # (c) Copyright IBM Corp. 2025 + try: + import contextvars from typing import Any, Callable, Dict, List, Optional, Tuple import confluent_kafka # noqa: F401 import wrapt from confluent_kafka import Consumer, Producer + from opentelemetry import context, trace from opentelemetry.trace import SpanKind from instana.log import logger from instana.propagators.format import Format + from instana.singletons import get_tracer from instana.util.traceutils import ( get_tracer_tuple, tracing_is_off, ) + from instana.span.span import InstanaSpan + + consumer_token = None + consumer_span = contextvars.ContextVar("confluent_kafka_consumer_span") # As confluent_kafka is a wrapper around the C-developed librdkafka # (provided automatically via binary wheels), we have to create new classes @@ -47,6 +55,9 @@ def poll( ) -> Optional[confluent_kafka.Message]: return super().poll(timeout) + def close(self) -> None: + return super().close() + def trace_kafka_produce( wrapped: Callable[..., InstanaConfluentKafkaProducer.produce], instance: InstanaConfluentKafkaProducer, @@ -105,25 +116,82 @@ def create_span( headers: Optional[List[Tuple[str, bytes]]] = [], exception: Optional[str] = None, ) -> None: - tracer, parent_span, _ = get_tracer_tuple() - parent_context = ( - parent_span.get_span_context() - if parent_span - else tracer.extract( - Format.KAFKA_HEADERS, - headers, - disable_w3c_trace_context=True, + try: + span = consumer_span.get(None) + if span is not None: + close_consumer_span(span) + + tracer, parent_span, _ = get_tracer_tuple() + + if not tracer: + tracer = get_tracer() + is_suppressed = False + + if topic: + is_suppressed = tracer.exporter._HostAgent__is_endpoint_ignored( + "kafka", + span_type, + topic, + ) + + if not is_suppressed and headers: + for header_name, header_value in headers: + if header_name == "x_instana_l_s" and header_value == b"0": + is_suppressed = True + break + + if is_suppressed: + return + + parent_context = ( + parent_span.get_span_context() + if parent_span + else ( + tracer.extract( + Format.KAFKA_HEADERS, + headers, + disable_w3c_trace_context=True, + ) + if tracer.exporter.options.kafka_trace_correlation + else None + ) + ) + span = tracer.start_span( + "kafka-consumer", span_context=parent_context, kind=SpanKind.CONSUMER ) - ) - with tracer.start_as_current_span( - "kafka-consumer", span_context=parent_context, kind=SpanKind.CONSUMER - ) as span: if topic: span.set_attribute("kafka.service", topic) span.set_attribute("kafka.access", span_type) - if exception: span.record_exception(exception) + span.end() + + save_consumer_span_into_context(span) + except Exception as e: + logger.debug( + f"Error while creating kafka-consumer span: {e}" + ) # pragma: no cover + + def save_consumer_span_into_context(span: "InstanaSpan") -> None: + global consumer_token + ctx = trace.set_span_in_context(span) + consumer_token = context.attach(ctx) + consumer_span.set(span) + + def close_consumer_span(span: "InstanaSpan") -> None: + global consumer_token + if span.is_recording(): + span.end() + consumer_span.set(None) + if consumer_token is not None: + context.detach(consumer_token) + consumer_token = None + + def clear_context() -> None: + global consumer_token + context.attach(trace.set_span_in_context(None)) + consumer_token = None + consumer_span.set(None) def trace_kafka_consume( wrapped: Callable[..., InstanaConfluentKafkaConsumer.consume], @@ -131,24 +199,41 @@ def trace_kafka_consume( args: Tuple[int, str, Tuple[Any, ...]], kwargs: Dict[str, Any], ) -> List[confluent_kafka.Message]: - if tracing_is_off(): - return wrapped(*args, **kwargs) - res = None exception = None try: res = wrapped(*args, **kwargs) + for message in res: + create_span("consume", message.topic(), message.headers()) + return res except Exception as exc: exception = exc - finally: - if res: - for message in res: - create_span("consume", message.topic(), message.headers()) - else: - create_span("consume", exception=exception) + create_span("consume", exception=exception) - return res + def trace_kafka_close( + wrapped: Callable[..., InstanaConfluentKafkaConsumer.close], + instance: InstanaConfluentKafkaConsumer, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + ) -> None: + try: + # Close any existing consumer span before closing the consumer + span = consumer_span.get(None) + if span is not None: + close_consumer_span(span) + + # Execute the actual close operation + res = wrapped(*args, **kwargs) + + logger.debug("Kafka consumer closed and spans cleaned up") + return res + + except Exception: + # Still try to clean up the span even if close fails + span = consumer_span.get(None) + if span is not None: + close_consumer_span(span) def trace_kafka_poll( wrapped: Callable[..., InstanaConfluentKafkaConsumer.poll], @@ -156,27 +241,20 @@ def trace_kafka_poll( args: Tuple[int, str, Tuple[Any, ...]], kwargs: Dict[str, Any], ) -> Optional[confluent_kafka.Message]: - if tracing_is_off(): - return wrapped(*args, **kwargs) - res = None exception = None try: res = wrapped(*args, **kwargs) + create_span("poll", res.topic(), res.headers()) + return res except Exception as exc: exception = exc - finally: - if res: - create_span("poll", res.topic(), res.headers()) - else: - create_span( - "poll", - next(iter(instance.list_topics().topics)), - exception=exception, - ) - - return res + create_span( + "poll", + next(iter(instance.list_topics().topics)), + exception=exception, + ) # Apply the monkey patch confluent_kafka.Producer = InstanaConfluentKafkaProducer @@ -189,6 +267,9 @@ def trace_kafka_poll( InstanaConfluentKafkaConsumer, "consume", trace_kafka_consume ) wrapt.wrap_function_wrapper(InstanaConfluentKafkaConsumer, "poll", trace_kafka_poll) + wrapt.wrap_function_wrapper( + InstanaConfluentKafkaConsumer, "close", trace_kafka_close + ) logger.debug("Instrumenting Kafka (confluent_kafka)") except ImportError: diff --git a/src/instana/instrumentation/kafka/kafka_python.py b/src/instana/instrumentation/kafka/kafka_python.py index 278390f9..c11e9355 100644 --- a/src/instana/instrumentation/kafka/kafka_python.py +++ b/src/instana/instrumentation/kafka/kafka_python.py @@ -1,23 +1,31 @@ # (c) Copyright IBM Corp. 2025 + try: + import contextvars import inspect from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple import kafka # noqa: F401 import wrapt + from opentelemetry import context, trace from opentelemetry.trace import SpanKind from instana.log import logger from instana.propagators.format import Format + from instana.singletons import get_tracer from instana.util.traceutils import ( get_tracer_tuple, tracing_is_off, ) + from instana.span.span import InstanaSpan if TYPE_CHECKING: from kafka.producer.future import FutureRecordMetadata + consumer_token = None + consumer_span = contextvars.ContextVar("kafka_python_consumer_span") + @wrapt.patch_function_wrapper("kafka", "KafkaProducer.send") def trace_kafka_send( wrapped: Callable[..., "kafka.KafkaProducer.send"], @@ -59,35 +67,86 @@ def trace_kafka_send( kwargs["headers"] = headers try: res = wrapped(*args, **kwargs) + return res except Exception as exc: span.record_exception(exc) - else: - return res def create_span( span_type: str, topic: Optional[str], headers: Optional[List[Tuple[str, bytes]]] = [], - exception: Optional[str] = None, + exception: Optional[Exception] = None, ) -> None: - tracer, parent_span, _ = get_tracer_tuple() - parent_context = ( - parent_span.get_span_context() - if parent_span - else tracer.extract( - Format.KAFKA_HEADERS, - headers, - disable_w3c_trace_context=True, + try: + span = consumer_span.get(None) + if span is not None: + close_consumer_span(span) + + tracer, parent_span, _ = get_tracer_tuple() + + if not tracer: + tracer = get_tracer() + + is_suppressed = False + if topic: + is_suppressed = tracer.exporter._HostAgent__is_endpoint_ignored( + "kafka", + span_type, + topic, + ) + + if not is_suppressed and headers: + for header_name, header_value in headers: + if header_name == "x_instana_l_s" and header_value == b"0": + is_suppressed = True + break + + if is_suppressed: + return + + parent_context = ( + parent_span.get_span_context() + if parent_span + else tracer.extract( + Format.KAFKA_HEADERS, + headers, + disable_w3c_trace_context=True, + ) + ) + span = tracer.start_span( + "kafka-consumer", span_context=parent_context, kind=SpanKind.CONSUMER ) - ) - with tracer.start_as_current_span( - "kafka-consumer", span_context=parent_context, kind=SpanKind.CONSUMER - ) as span: if topic: span.set_attribute("kafka.service", topic) span.set_attribute("kafka.access", span_type) if exception: span.record_exception(exception) + span.end() + + save_consumer_span_into_context(span) + except Exception: + pass + + def save_consumer_span_into_context(span: "InstanaSpan") -> None: + global consumer_token + ctx = trace.set_span_in_context(span) + consumer_token = context.attach(ctx) + consumer_span.set(span) + + def close_consumer_span(span: "InstanaSpan") -> None: + global consumer_token + if span.is_recording(): + span.end() + consumer_span.set(None) + if consumer_token is not None: + context.detach(consumer_token) + consumer_token = None + + def clear_context() -> None: + global consumer_token + context.attach(trace.set_span_in_context(None)) + consumer_token = None + consumer_span.set(None) @wrapt.patch_function_wrapper("kafka", "KafkaConsumer.__next__") def trace_kafka_consume( @@ -96,29 +155,41 @@ def trace_kafka_consume( args: Tuple[int, str, Tuple[Any, ...]], kwargs: Dict[str, Any], ) -> "FutureRecordMetadata": - if tracing_is_off(): - return wrapped(*args, **kwargs) - exception = None res = None try: res = wrapped(*args, **kwargs) + create_span( + "consume", + res.topic if res else list(instance.subscription())[0], + res.headers, + ) + return res + except StopIteration: + pass except Exception as exc: exception = exc - finally: - if res: - create_span( - "consume", - res.topic if res else list(instance.subscription())[0], - res.headers, - ) - else: - create_span( - "consume", list(instance.subscription())[0], exception=exception - ) + create_span( + "consume", list(instance.subscription())[0], exception=exception + ) - return res + @wrapt.patch_function_wrapper("kafka", "KafkaConsumer.close") + def trace_kafka_close( + wrapped: Callable[..., None], + instance: "kafka.KafkaConsumer", + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + ) -> None: + try: + span = consumer_span.get(None) + if span is not None: + close_consumer_span(span) + except Exception as e: + logger.debug( + f"Error while closing kafka-consumer span: {e}" + ) # pragma: no cover + return wrapped(*args, **kwargs) @wrapt.patch_function_wrapper("kafka", "KafkaConsumer.poll") def trace_kafka_poll( @@ -127,9 +198,6 @@ def trace_kafka_poll( args: Tuple[int, str, Tuple[Any, ...]], kwargs: Dict[str, Any], ) -> Optional[Dict[str, Any]]: - if tracing_is_off(): - return wrapped(*args, **kwargs) - # The KafkaConsumer.consume() from the kafka-python-ng call the # KafkaConsumer.poll() internally, so we do not consider it here. if any( @@ -143,23 +211,17 @@ def trace_kafka_poll( try: res = wrapped(*args, **kwargs) + for partition, consumer_records in res.items(): + for message in consumer_records: + create_span( + "poll", + partition.topic, + message.headers if hasattr(message, "headers") else [], + ) + return res except Exception as exc: exception = exc - finally: - if res: - for partition, consumer_records in res.items(): - for message in consumer_records: - create_span( - "poll", - partition.topic, - message.headers if hasattr(message, "headers") else [], - ) - else: - create_span( - "poll", list(instance.subscription())[0], exception=exception - ) - - return res + create_span("poll", list(instance.subscription())[0], exception=exception) logger.debug("Instrumenting Kafka (kafka-python)") except ImportError: diff --git a/tests/clients/kafka/test_confluent_kafka.py b/tests/clients/kafka/test_confluent_kafka.py index fb9ab4c8..61f31bce 100644 --- a/tests/clients/kafka/test_confluent_kafka.py +++ b/tests/clients/kafka/test_confluent_kafka.py @@ -11,7 +11,7 @@ Producer, ) from confluent_kafka.admin import AdminClient, NewTopic -from mock import patch +from mock import patch, Mock from opentelemetry.trace import SpanKind from opentelemetry.trace.span import format_span_id @@ -20,6 +20,15 @@ from instana.singletons import agent, tracer from instana.util.config import parse_ignored_endpoints_from_yaml from tests.helpers import get_first_span_by_filter, testenv +from instana.instrumentation.kafka import confluent_kafka_python +from instana.instrumentation.kafka.confluent_kafka_python import ( + clear_context, + save_consumer_span_into_context, + close_consumer_span, + trace_kafka_close, + consumer_span, +) +from instana.span.span import InstanaSpan class TestConfluentKafka: @@ -68,8 +77,12 @@ def _resource(self) -> Generator[None, None, None]: agent.options = StandardOptions() yield # teardown - # Ensure that allow_exit_as_root has the default value""" - agent.options.allow_exit_as_root = False + # Clear spans before resetting options + self.recorder.clear_spans() + + # Clear context + clear_context() + # Close connections self.kafka_client.delete_topics( [ @@ -129,24 +142,6 @@ def test_trace_confluent_kafka_consume(self) -> None: spans = self.recorder.queued_spans() assert len(spans) == 2 - kafka_span = spans[0] - test_span = spans[1] - - # Same traceId - assert test_span.t == kafka_span.t - - # Parent relationships - assert kafka_span.p == test_span.s - - # Error logging - assert not test_span.ec - assert not kafka_span.ec - - assert kafka_span.n == "kafka" - assert kafka_span.k == SpanKind.SERVER - assert kafka_span.data["kafka"]["service"] == testenv["kafka_topic"] - assert kafka_span.data["kafka"]["access"] == "consume" - def test_trace_confluent_kafka_poll(self) -> None: # Produce some events self.producer.produce(testenv["kafka_topic"], b"raw_bytes1") @@ -162,15 +157,22 @@ def test_trace_confluent_kafka_poll(self) -> None: consumer.subscribe([testenv["kafka_topic"]]) with tracer.start_as_current_span("test"): - msg = consumer.poll(timeout=30) # noqa: F841 + msg = consumer.poll(timeout=3) # noqa: F841 consumer.close() spans = self.recorder.queued_spans() assert len(spans) == 2 - kafka_span = spans[0] - test_span = spans[1] + def filter(span): + return span.n == "kafka" and span.data["kafka"]["access"] == "poll" + + kafka_span = get_first_span_by_filter(spans, filter) + + def filter(span): + return span.n == "sdk" and span.data["sdk"]["name"] == "test" + + test_span = get_first_span_by_filter(spans, filter) # Same traceId assert test_span.t == kafka_span.t @@ -282,10 +284,7 @@ def test_ignore_confluent_kafka_consumer(self) -> None: consumer.close() spans = self.recorder.queued_spans() - assert len(spans) == 3 - - filtered_spans = agent.filter_spans(spans) - assert len(filtered_spans) == 1 + assert len(spans) == 1 @patch.dict( os.environ, @@ -323,7 +322,7 @@ def test_ignore_confluent_specific_topic(self) -> None: consumer.close() spans = self.recorder.queued_spans() - assert len(spans) == 5 + assert len(spans) == 4 filtered_spans = agent.filter_spans(spans) assert len(filtered_spans) == 3 @@ -362,7 +361,7 @@ def test_ignore_confluent_specific_topic_with_config_file(self) -> None: consumer.close() spans = self.recorder.queued_spans() - assert len(spans) == 3 + assert len(spans) == 2 filtered_spans = agent.filter_spans(spans) assert len(filtered_spans) == 1 @@ -482,7 +481,7 @@ def test_confluent_kafka_poll_root_exit_without_trace_correlation(self) -> None: agent.options.kafka_trace_correlation = False # Produce some events - self.producer.produce(testenv["kafka_topic"], b"raw_bytes1") + self.producer.produce(f'{testenv["kafka_topic"]}-wo-tc', b"raw_bytes1") self.producer.flush() # Consume the events @@ -491,7 +490,7 @@ def test_confluent_kafka_poll_root_exit_without_trace_correlation(self) -> None: consumer_config["auto.offset.reset"] = "earliest" consumer = Consumer(consumer_config) - consumer.subscribe([testenv["kafka_topic"]]) + consumer.subscribe([f'{testenv["kafka_topic"]}-wo-tc']) msg = consumer.poll(timeout=30) # noqa: F841 @@ -504,14 +503,14 @@ def test_confluent_kafka_poll_root_exit_without_trace_correlation(self) -> None: spans, lambda span: span.n == "kafka" and span.data["kafka"]["access"] == "produce" - and span.data["kafka"]["service"] == "span-topic", + and span.data["kafka"]["service"] == f'{testenv["kafka_topic"]}-wo-tc', ) poll_span = get_first_span_by_filter( spans, lambda span: span.n == "kafka" and span.data["kafka"]["access"] == "poll" - and span.data["kafka"]["service"] == "span-topic", + and span.data["kafka"]["service"] == f'{testenv["kafka_topic"]}-wo-tc', ) # Different traceId @@ -598,7 +597,7 @@ def test_confluent_kafka_downstream_suppression(self) -> None: consumer.close() spans = self.recorder.queued_spans() - assert len(spans) == 3 + assert len(spans) == 2 producer_span_1 = get_first_span_by_filter( spans, @@ -628,10 +627,7 @@ def test_confluent_kafka_downstream_suppression(self) -> None: assert producer_span_1 # consumer has been suppressed assert not consumer_span_1 - - assert producer_span_2.t == consumer_span_2.t - assert producer_span_2.s == consumer_span_2.p - assert producer_span_2.s != consumer_span_2.s + assert not consumer_span_2 for message in messages: if message.topic() == "span-topic_1": @@ -649,3 +645,75 @@ def test_confluent_kafka_downstream_suppression(self) -> None: testenv["kafka_topic"] + "_2", ] ) + + def test_save_consumer_span_into_context(self, span: "InstanaSpan") -> None: + """Test save_consumer_span_into_context function.""" + # Verify initial state + assert consumer_span.get(None) is None + assert confluent_kafka_python.consumer_token is None + + # Save span into context + save_consumer_span_into_context(span) + + # Verify token is stored + assert confluent_kafka_python.consumer_token is not None + + def test_close_consumer_span_recording_span(self, span: "InstanaSpan") -> None: + """Test close_consumer_span with a recording span.""" + # Save span into context first + save_consumer_span_into_context(span) + assert confluent_kafka_python.consumer_token is not None + + # Verify span is recording + assert span.is_recording() + + # Close the span + close_consumer_span(span) + + # Verify span was ended and context cleared + assert not span.is_recording() + assert consumer_span.get(None) is None + assert confluent_kafka_python.consumer_token is None + + def test_clear_context(self, span: "InstanaSpan") -> None: + """Test clear_context function.""" + # Save span into context + save_consumer_span_into_context(span) + + # Verify context has data + assert consumer_span.get(None) == span + assert confluent_kafka_python.consumer_token is not None + + # Clear context + clear_context() + + # Verify all context is cleared + assert consumer_span.get(None) is None + assert confluent_kafka_python.consumer_token is None + + def test_trace_kafka_close_exception_handling(self, span: "InstanaSpan") -> None: + """Test trace_kafka_close handles exceptions and still cleans up spans.""" + # Save span into context + save_consumer_span_into_context(span) + + # Verify span is in context + assert consumer_span.get(None) == span + assert confluent_kafka_python.consumer_token is not None + + # Mock a wrapped function that raises an exception + mock_wrapped = Mock(side_effect=Exception("Close operation failed")) + mock_instance = Mock() + + # Call trace_kafka_close - it should handle the exception gracefully + # and still clean up the span + trace_kafka_close(mock_wrapped, mock_instance, (), {}) + + # Verify the wrapped function was called + mock_wrapped.assert_called_once_with() + + # Verify that despite the exception, the span was cleaned up + assert consumer_span.get(None) is None + assert confluent_kafka_python.consumer_token is None + + # Verify span was ended + assert not span.is_recording() diff --git a/tests/clients/kafka/test_kafka_python.py b/tests/clients/kafka/test_kafka_python.py index dd568583..eb3723e3 100644 --- a/tests/clients/kafka/test_kafka_python.py +++ b/tests/clients/kafka/test_kafka_python.py @@ -17,6 +17,15 @@ from instana.util.config import parse_ignored_endpoints_from_yaml from tests.helpers import get_first_span_by_filter, testenv +from instana.instrumentation.kafka import kafka_python +from instana.instrumentation.kafka.kafka_python import ( + clear_context, + save_consumer_span_into_context, + close_consumer_span, + consumer_span, +) +from instana.span.span import InstanaSpan + class TestKafkaPython: @pytest.fixture(autouse=True) @@ -72,6 +81,10 @@ def _resource(self) -> Generator[None, None, None]: agent.options.allow_exit_as_root = False # Close connections self.producer.close() + + # Clear context + clear_context() + self.kafka_client.delete_topics( [ testenv["kafka_topic"], @@ -132,10 +145,17 @@ def test_trace_kafka_python_consume(self) -> None: consumer.close() spans = self.recorder.queued_spans() - assert len(spans) == 4 + assert len(spans) == 3 - kafka_span = spans[0] - test_span = spans[len(spans) - 1] + def filter(span): + return span.n == "kafka" and span.data["kafka"]["access"] == "consume" + + kafka_span = get_first_span_by_filter(spans, filter) + + def filter(span): + return span.n == "sdk" and span.data["sdk"]["name"] == "test" + + test_span = get_first_span_by_filter(spans, filter) # Same traceId assert test_span.t == kafka_span.t @@ -168,15 +188,22 @@ def test_trace_kafka_python_poll(self) -> None: ) with tracer.start_as_current_span("test"): - msg = consumer.poll() # noqa: F841 + msg = consumer.poll(timeout_ms=3000) # noqa: F841 consumer.close() spans = self.recorder.queued_spans() - assert len(spans) == 2 + assert len(spans) == 3 - kafka_span = spans[0] - test_span = spans[1] + def filter(span): + return span.n == "kafka" and span.data["kafka"]["access"] == "poll" + + kafka_span = get_first_span_by_filter(spans, filter) + + def filter(span): + return span.n == "sdk" and span.data["sdk"]["name"] == "test" + + test_span = get_first_span_by_filter(spans, filter) # Same traceId assert test_span.t == kafka_span.t @@ -194,27 +221,36 @@ def test_trace_kafka_python_poll(self) -> None: assert kafka_span.data["kafka"]["access"] == "poll" def test_trace_kafka_python_error(self) -> None: - # Consume the events consumer = KafkaConsumer( "inexistent_kafka_topic", bootstrap_servers=testenv["kafka_bootstrap_servers"], - auto_offset_reset="earliest", # consume earliest available messages - enable_auto_commit=False, # do not auto-commit offsets + auto_offset_reset="earliest", + enable_auto_commit=False, consumer_timeout_ms=1000, ) with tracer.start_as_current_span("test"): - for msg in consumer: - if msg is None: - break + consumer._client = None - consumer.close() + try: + for msg in consumer: + if msg is None: + break + except Exception: + pass spans = self.recorder.queued_spans() assert len(spans) == 2 - kafka_span = spans[0] - test_span = spans[1] + def filter(span): + return span.n == "kafka" and span.data["kafka"]["access"] == "consume" + + kafka_span = get_first_span_by_filter(spans, filter) + + def filter(span): + return span.n == "sdk" and span.data["sdk"]["name"] == "test" + + test_span = get_first_span_by_filter(spans, filter) # Same traceId assert test_span.t == kafka_span.t @@ -230,7 +266,10 @@ def test_trace_kafka_python_error(self) -> None: assert kafka_span.k == SpanKind.SERVER assert kafka_span.data["kafka"]["service"] == "inexistent_kafka_topic" assert kafka_span.data["kafka"]["access"] == "consume" - assert kafka_span.data["kafka"]["error"] == "StopIteration()" + assert ( + kafka_span.data["kafka"]["error"] + == "'NoneType' object has no attribute 'poll'" + ) def consume_from_topic(self, topic_name: str) -> None: consumer = KafkaConsumer( @@ -302,10 +341,7 @@ def test_ignore_kafka_consumer(self) -> None: self.consume_from_topic(testenv["kafka_topic"]) spans = self.recorder.queued_spans() - assert len(spans) == 4 - - filtered_spans = agent.filter_spans(spans) - assert len(filtered_spans) == 1 + assert len(spans) == 1 @patch.dict( os.environ, @@ -326,10 +362,10 @@ def test_ignore_specific_topic(self) -> None: self.consume_from_topic(testenv["kafka_topic"] + "_1") spans = self.recorder.queued_spans() - assert len(spans) == 11 + assert len(spans) == 7 filtered_spans = agent.filter_spans(spans) - assert len(filtered_spans) == 8 + assert len(filtered_spans) == 6 span_to_be_filtered = get_first_span_by_filter( spans, @@ -351,10 +387,7 @@ def test_ignore_specific_topic_with_config_file(self) -> None: self.consume_from_topic(testenv["kafka_topic"]) spans = self.recorder.queued_spans() - assert len(spans) == 3 - - filtered_spans = agent.filter_spans(spans) - assert len(filtered_spans) == 1 + assert len(spans) == 1 def test_kafka_consumer_root_exit(self) -> None: agent.options.allow_exit_as_root = True @@ -378,7 +411,7 @@ def test_kafka_consumer_root_exit(self) -> None: consumer.close() spans = self.recorder.queued_spans() - assert len(spans) == 4 + assert len(spans) == 3 producer_span = spans[0] consumer_span = spans[1] @@ -713,3 +746,50 @@ def test_kafka_downstream_suppression(self) -> None: format_span_id(producer_span_2.s).encode("utf-8"), ), ] + + def test_save_consumer_span_into_context(self, span: "InstanaSpan") -> None: + """Test save_consumer_span_into_context function.""" + # Verify initial state + assert consumer_span.get(None) is None + assert kafka_python.consumer_token is None + + # Save span into context + save_consumer_span_into_context(span) + + # Verify span is saved in context variable + assert consumer_span.get(None) == span + # Verify token is stored + assert kafka_python.consumer_token is not None + + def test_close_consumer_span_recording_span(self, span: "InstanaSpan") -> None: + """Test close_consumer_span with a recording span.""" + # Save span into context first + save_consumer_span_into_context(span) + assert kafka_python.consumer_token is not None + + # Verify span is recording + assert span.is_recording() + + # Close the span + close_consumer_span(span) + + # Verify span was ended and context cleared + assert not span.is_recording() + assert consumer_span.get(None) is None + assert kafka_python.consumer_token is None + + def test_clear_context(self, span: "InstanaSpan") -> None: + """Test clear_context function.""" + # Save span into context + save_consumer_span_into_context(span) + + # Verify context has data + assert consumer_span.get(None) == span + assert kafka_python.consumer_token is not None + + # Clear context + clear_context() + + # Verify all context is cleared + assert consumer_span.get(None) is None + assert kafka_python.consumer_token is None