KAFKA-16505: Add source raw key and value #18739

Merged: 1 commit, Jun 5, 2025
@@ -16,29 +16,40 @@
*/
package org.apache.kafka.streams.integration;

import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.common.MetricName;
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.common.serialization.StringSerializer;
import org.apache.kafka.common.utils.Bytes;
import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.KeyValueTimestamp;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.StreamsConfig;
import org.apache.kafka.streams.TestInputTopic;
import org.apache.kafka.streams.Topology;
import org.apache.kafka.streams.TopologyTestDriver;
import org.apache.kafka.streams.errors.ErrorHandlerContext;
import org.apache.kafka.streams.errors.LogAndContinueProcessingExceptionHandler;
import org.apache.kafka.streams.errors.LogAndFailProcessingExceptionHandler;
import org.apache.kafka.streams.errors.ProcessingExceptionHandler;
import org.apache.kafka.streams.errors.StreamsException;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.Grouped;
import org.apache.kafka.streams.kstream.JoinWindows;
import org.apache.kafka.streams.kstream.Materialized;
import org.apache.kafka.streams.kstream.StreamJoined;
import org.apache.kafka.streams.processor.api.ContextualProcessor;
import org.apache.kafka.streams.processor.api.ProcessorSupplier;
import org.apache.kafka.streams.processor.api.Record;
import org.apache.kafka.streams.state.KeyValueStore;
import org.apache.kafka.test.MockProcessorSupplier;

import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import java.time.Duration;
import java.time.Instant;
@@ -48,6 +59,7 @@
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Stream;

import static org.apache.kafka.common.utils.Utils.mkEntry;
import static org.apache.kafka.common.utils.Utils.mkMap;
@@ -385,6 +397,131 @@ public void shouldStopProcessingWhenFatalUserExceptionProcessingExceptionHandler
}
}

static Stream<Arguments> sourceRawRecordTopologyTestCases() {
// Validate source raw key and source raw value for fully stateless topology
final List<ProducerRecord<String, String>> statelessTopologyEvent = List.of(new ProducerRecord<>("TOPIC_NAME", "ID123-1", "ID123-A1"));
final StreamsBuilder statelessTopologyBuilder = new StreamsBuilder();
statelessTopologyBuilder
.stream("TOPIC_NAME", Consumed.with(Serdes.String(), Serdes.String()))
.selectKey((key, value) -> "newKey")
.mapValues(value -> {
throw new RuntimeException("Error");
});

// Validate source raw key and source raw value for processing exception in aggregator with caching enabled
final List<ProducerRecord<String, String>> cacheAggregateExceptionInAggregatorEvent = List.of(new ProducerRecord<>("TOPIC_NAME", "INITIAL-KEY123-1", "ID123-A1"));
final StreamsBuilder cacheAggregateExceptionInAggregatorTopologyBuilder = new StreamsBuilder();
cacheAggregateExceptionInAggregatorTopologyBuilder
.stream("TOPIC_NAME", Consumed.with(Serdes.String(), Serdes.String()))
.groupBy((key, value) -> "ID123-1", Grouped.with(Serdes.String(), Serdes.String()))
.aggregate(() -> "initialValue",
(key, value, aggregate) -> {
throw new RuntimeException("Error");
},
Materialized.<String, String, KeyValueStore<Bytes, byte[]>>as("aggregate")
.withKeySerde(Serdes.String())
.withValueSerde(Serdes.String())
.withCachingEnabled());

// Validate source raw key and source raw value for processing exception after aggregation with caching enabled
final List<ProducerRecord<String, String>> cacheAggregateExceptionAfterAggregationEvent = List.of(new ProducerRecord<>("TOPIC_NAME", "INITIAL-KEY123-1", "ID123-A1"));
final StreamsBuilder cacheAggregateExceptionAfterAggregationTopologyBuilder = new StreamsBuilder();
cacheAggregateExceptionAfterAggregationTopologyBuilder
.stream("TOPIC_NAME", Consumed.with(Serdes.String(), Serdes.String()))
.groupBy((key, value) -> "ID123-1", Grouped.with(Serdes.String(), Serdes.String()))
.aggregate(() -> "initialValue",
(key, value, aggregate) -> value,
Materialized.<String, String, KeyValueStore<Bytes, byte[]>>as("aggregate")
.withKeySerde(Serdes.String())
.withValueSerde(Serdes.String())
.withCachingEnabled())
.mapValues(value -> {
throw new RuntimeException("Error");
});

// Validate source raw key and source raw value for processing exception after aggregation with caching disabled
final List<ProducerRecord<String, String>> noCacheAggregateExceptionAfterAggregationEvents = List.of(new ProducerRecord<>("TOPIC_NAME", "INITIAL-KEY123-1", "ID123-A1"));
final StreamsBuilder noCacheAggregateExceptionAfterAggregationTopologyBuilder = new StreamsBuilder();
noCacheAggregateExceptionAfterAggregationTopologyBuilder
.stream("TOPIC_NAME", Consumed.with(Serdes.String(), Serdes.String()))
.groupBy((key, value) -> "ID123-1", Grouped.with(Serdes.String(), Serdes.String()))
.aggregate(() -> "initialValue",
(key, value, aggregate) -> value,
Materialized.<String, String, KeyValueStore<Bytes, byte[]>>as("aggregate")
.withKeySerde(Serdes.String())
.withValueSerde(Serdes.String())
.withCachingDisabled())
.mapValues(value -> {
throw new RuntimeException("Error");
});

// Validate source raw key and source raw value for processing exception after table creation with caching enabled
final List<ProducerRecord<String, String>> cacheTableEvents = List.of(new ProducerRecord<>("TOPIC_NAME", "ID123-1", "ID123-A1"));
final StreamsBuilder cacheTableTopologyBuilder = new StreamsBuilder();
cacheTableTopologyBuilder
.table("TOPIC_NAME", Consumed.with(Serdes.String(), Serdes.String()),
Materialized.<String, String, KeyValueStore<Bytes, byte[]>>as("table")
.withKeySerde(Serdes.String())
.withValueSerde(Serdes.String())
.withCachingEnabled())
.mapValues(value -> {
throw new RuntimeException("Error");
});

// Validate source raw key and source raw value for processing exception in join
final List<ProducerRecord<String, String>> joinEvents = List.of(
new ProducerRecord<>("TOPIC_NAME_2", "INITIAL-KEY123-1", "ID123-A1"),
new ProducerRecord<>("TOPIC_NAME", "INITIAL-KEY123-2", "ID123-A1")
);
final StreamsBuilder joinTopologyBuilder = new StreamsBuilder();
joinTopologyBuilder
.stream("TOPIC_NAME", Consumed.with(Serdes.String(), Serdes.String()))
.selectKey((key, value) -> "ID123-1")
.leftJoin(joinTopologyBuilder.stream("TOPIC_NAME_2", Consumed.with(Serdes.String(), Serdes.String()))
.selectKey((key, value) -> "ID123-1"),
(key, left, right) -> {
throw new RuntimeException("Error");
},
JoinWindows.ofTimeDifferenceAndGrace(Duration.ofMinutes(5), Duration.ofMinutes(1)),
StreamJoined.with(
Serdes.String(), Serdes.String(), Serdes.String())
.withName("join-rekey")
.withStoreName("join-store"));

return Stream.of(
Arguments.of(statelessTopologyEvent, statelessTopologyBuilder.build()),
Arguments.of(cacheAggregateExceptionInAggregatorEvent, cacheAggregateExceptionInAggregatorTopologyBuilder.build()),
Arguments.of(cacheAggregateExceptionAfterAggregationEvent, cacheAggregateExceptionAfterAggregationTopologyBuilder.build()),
Arguments.of(noCacheAggregateExceptionAfterAggregationEvents, noCacheAggregateExceptionAfterAggregationTopologyBuilder.build()),
Arguments.of(cacheTableEvents, cacheTableTopologyBuilder.build()),
Arguments.of(joinEvents, joinTopologyBuilder.build())
);
}

@ParameterizedTest
@MethodSource("sourceRawRecordTopologyTestCases")
public void shouldVerifySourceRawKeyAndSourceRawValuePresentOrNotInErrorHandlerContext(final List<ProducerRecord<String, String>> events,
final Topology topology) {
final Properties properties = new Properties();
properties.put(StreamsConfig.PROCESSING_EXCEPTION_HANDLER_CLASS_CONFIG,
AssertSourceRawRecordProcessingExceptionHandlerMockTest.class);

try (final TopologyTestDriver driver = new TopologyTestDriver(topology, properties, Instant.ofEpochMilli(0L))) {
for (final ProducerRecord<String, String> event : events) {
final TestInputTopic<String, String> inputTopic = driver.createInputTopic(event.topic(), new StringSerializer(), new StringSerializer());

final String key = event.key();
final String value = event.value();

if (event.topic().equals("TOPIC_NAME")) {
assertThrows(StreamsException.class, () -> inputTopic.pipeInput(key, value, TIMESTAMP));
} else {
inputTopic.pipeInput(event.key(), event.value(), TIMESTAMP);
}
}
}
}

public static class ContinueProcessingExceptionHandlerMockTest implements ProcessingExceptionHandler {
@Override
public ProcessingExceptionHandler.ProcessingHandlerResponse handle(final ErrorHandlerContext context, final Record<?, ?> record, final Exception exception) {
@@ -422,10 +559,28 @@ private static void assertProcessingExceptionHandlerInputs(final ErrorHandlerCon
assertTrue(Arrays.asList("ID123-A2", "ID123-A5").contains((String) record.value()));
assertEquals("TOPIC_NAME", context.topic());
assertEquals("KSTREAM-PROCESSOR-0000000003", context.processorNodeId());
assertTrue(Arrays.equals("ID123-2-ERR".getBytes(), context.sourceRawKey())
|| Arrays.equals("ID123-5-ERR".getBytes(), context.sourceRawKey()));
assertTrue(Arrays.equals("ID123-A2".getBytes(), context.sourceRawValue())
|| Arrays.equals("ID123-A5".getBytes(), context.sourceRawValue()));
assertEquals(TIMESTAMP.toEpochMilli(), context.timestamp());
assertTrue(exception.getMessage().contains("Exception should be handled by processing exception handler"));
}

public static class AssertSourceRawRecordProcessingExceptionHandlerMockTest implements ProcessingExceptionHandler {
@Override
public ProcessingExceptionHandler.ProcessingHandlerResponse handle(final ErrorHandlerContext context, final Record<?, ?> record, final Exception exception) {
assertEquals("ID123-1", Serdes.String().deserializer().deserialize("topic", context.sourceRawKey()));
assertEquals("ID123-A1", Serdes.String().deserializer().deserialize("topic", context.sourceRawValue()));
return ProcessingExceptionHandler.ProcessingHandlerResponse.FAIL;
}

@Override
public void configure(final Map<String, ?> configs) {
// No-op
}
}

/**
* Metric name for dropped records total.
*
@@ -147,4 +147,38 @@ public interface ErrorHandlerContext {
* @return The timestamp.
*/
long timestamp();

/**
* Return the non-deserialized byte[] of the input message key if the context has been triggered by a message.
*
* <p> If this method is invoked within a {@link Punctuator#punctuate(long)
* punctuation callback}, or while processing a record that was forwarded by a punctuation
* callback, it will return {@code null}.
Review comment (Member): nit: the Javadoc should say {@code null} rather than null.
Reply (Contributor Author): Fixed
*
* <p> If this method is invoked in a sub-topology due to a repartition, the returned key would be the one sent
* to the repartition topic.
*
* <p> Always returns {@code null} if this method is invoked within a
* {@code ProductionExceptionHandler.handle(ErrorHandlerContext, ProducerRecord, Exception)} call.
*
* @return the raw bytes of the key of the source message
*/
byte[] sourceRawKey();

/**
* Return the non-deserialized byte[] of the input message value if the context has been triggered by a message.
*
* <p> If this method is invoked within a {@link Punctuator#punctuate(long)
* punctuation callback}, or while processing a record that was forwarded by a punctuation
* callback, it will return {@code null}.
*
* <p> If this method is invoked in a sub-topology due to a repartition, the returned value would be the one sent
* to the repartition topic.
*
* <p> Always returns {@code null} if this method is invoked within a
* {@code ProductionExceptionHandler.handle(ErrorHandlerContext, ProducerRecord, Exception)} call.
*
* @return the raw bytes of the value of the source message
*/
byte[] sourceRawValue();
}
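
As a usage illustration (not part of this PR): a minimal sketch of a custom ProcessingExceptionHandler that publishes the raw source bytes to a dead-letter topic. The handler class name, the "dlq-topic" topic, and the producer wiring are assumptions; note the null guard for the punctuation case described in the Javadoc above.

// Hypothetical dead-letter handler built on the new accessors; the topic name
// and producer configuration below are illustrative assumptions.
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.Producer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.common.serialization.ByteArraySerializer;
import org.apache.kafka.streams.errors.ErrorHandlerContext;
import org.apache.kafka.streams.errors.ProcessingExceptionHandler;
import org.apache.kafka.streams.processor.api.Record;

import java.util.HashMap;
import java.util.Map;

public class DeadLetterProcessingExceptionHandler implements ProcessingExceptionHandler {
    private Producer<byte[], byte[]> producer;

    @Override
    public ProcessingExceptionHandler.ProcessingHandlerResponse handle(final ErrorHandlerContext context, final Record<?, ?> record, final Exception exception) {
        // sourceRawKey()/sourceRawValue() return null for punctuation-forwarded
        // records, so guard before publishing the raw bytes.
        if (context.sourceRawKey() != null || context.sourceRawValue() != null) {
            producer.send(new ProducerRecord<>("dlq-topic", context.sourceRawKey(), context.sourceRawValue()));
        }
        return ProcessingExceptionHandler.ProcessingHandlerResponse.CONTINUE;
    }

    @Override
    public void configure(final Map<String, ?> configs) {
        // Reuse the Streams config (which includes bootstrap.servers) for the
        // dead-letter producer; a production setup would filter these configs.
        producer = new KafkaProducer<>(new HashMap<>(configs), new ByteArraySerializer(), new ByteArraySerializer());
    }
}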
@@ -33,6 +33,8 @@ public class DefaultErrorHandlerContext implements ErrorHandlerContext {
private final Headers headers;
private final String processorNodeId;
private final TaskId taskId;
private final byte[] sourceRawKey;
private final byte[] sourceRawValue;

private final long timestamp;
private final ProcessorContext processorContext;
@@ -44,7 +46,9 @@ public DefaultErrorHandlerContext(final ProcessorContext processorContext,
final Headers headers,
final String processorNodeId,
final TaskId taskId,
final long timestamp) {
final long timestamp,
final byte[] sourceRawKey,
final byte[] sourceRawValue) {
this.topic = topic;
this.partition = partition;
this.offset = offset;
@@ -53,6 +57,8 @@ public DefaultErrorHandlerContext(final ProcessorContext processorContext,
this.taskId = taskId;
this.processorContext = processorContext;
this.timestamp = timestamp;
this.sourceRawKey = sourceRawKey;
this.sourceRawValue = sourceRawValue;
}

@Override
@@ -90,6 +96,14 @@ public long timestamp() {
return timestamp;
}

@Override
public byte[] sourceRawKey() {
return sourceRawKey;
}

@Override
public byte[] sourceRawValue() {
return sourceRawValue;
}

@Override
public String toString() {
// we do exclude headers on purpose, to not accidentally log user data
@@ -110,4 +110,31 @@ public interface RecordContext {
*/
Headers headers();

/**
* Return the non-deserialized byte[] of the input message key if the context has been triggered by a message.
*
* <p> If this method is invoked within a {@link Punctuator#punctuate(long)
* punctuation callback}, or while processing a record that was forwarded by a punctuation
* callback, it will return {@code null}.
*
* <p> If this method is invoked in a sub-topology due to a repartition, the returned key would be the one sent
* to the repartition topic.
*
* @return the raw bytes of the key of the source message
*/
byte[] sourceRawKey();

/**
* Return the non-deserialized byte[] of the input message value if the context has been triggered by a message.
*
* <p> If this method is invoked within a {@link Punctuator#punctuate(long)
* punctuation callback}, or while processing a record that was forwarded by a punctuation
* callback, it will return {@code null}.
*
* <p> If this method is invoked in a sub-topology due to a repartition, the returned value would be the one sent
* to the repartition topic.
*
* @return the raw bytes of the value of the source message
*/
byte[] sourceRawValue();
}
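
Since RecordContext is also the context object handed to a TopicNameExtractor, the new accessors surface there as well. Below is a minimal routing sketch under assumed topic and class names, exercising the null contract for punctuation-forwarded records.

// Hypothetical routing sketch: TopicNameExtractor receives a RecordContext, so
// the raw source bytes are available when choosing an output topic. The class
// and topic names are illustrative assumptions.
import org.apache.kafka.streams.processor.RecordContext;
import org.apache.kafka.streams.processor.TopicNameExtractor;

public class RawKeyAwareTopicNameExtractor implements TopicNameExtractor<String, String> {
    @Override
    public String extract(final String key, final String value, final RecordContext recordContext) {
        // Punctuation-forwarded records carry no raw source bytes, so
        // sourceRawKey() returns null for them; route those separately.
        return recordContext.sourceRawKey() == null ? "punctuation-output" : "regular-output";
    }
}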
@@ -260,7 +260,10 @@ public <K, V> void forward(final Record<K, V> record, final String childName) {
recordContext.offset(),
recordContext.partition(),
recordContext.topic(),
record.headers());
record.headers(),
recordContext.sourceRawKey(),
recordContext.sourceRawValue()
);
}

if (childName == null) {
@@ -215,7 +215,9 @@ public void process(final Record<KIn, VIn> record) {
internalProcessorContext.recordContext().headers(),
internalProcessorContext.currentNode().name(),
internalProcessorContext.taskId(),
internalProcessorContext.recordContext().timestamp()
internalProcessorContext.recordContext().timestamp(),
internalProcessorContext.recordContext().sourceRawKey(),
internalProcessorContext.recordContext().sourceRawValue()
);

final ProcessingExceptionHandler.ProcessingHandlerResponse response;