diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 066006ce7082e..d3b6572caa787 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2292,6 +2292,15 @@ object SQLConf { .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2") .createWithDefault(2) + val FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS = + buildConf("spark.sql.streaming.flatMapGroupsWithState.skipEmittingInitialStateKeys") + .internal() + .doc("When true, the flatMapGroupsWithState operation in a streaming query will not emit " + + "results for the initial state keys of each group.") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + val CHECKPOINT_LOCATION = buildConf("spark.sql.streaming.checkpointLocation") .doc("The default location for storing checkpoint data for streaming queries.") .version("2.0.0") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 36e25773f8342..bf169c4c99ff4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -736,11 +736,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _, timeout, hasInitialState, stateGroupAttr, sda, sDeser, initialState, child) => val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) + val skipEmittingInitialStateKeys = + conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS) val execPlan = FlatMapGroupsWithStateExec( func, keyDeser, valueDeser, sDeser, groupAttr, stateGroupAttr, dataAttr, sda, outputAttr, None, stateEnc, stateVersion, outputMode, timeout, batchTimestampMs = None, eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None, - planLater(initialState), hasInitialState, planLater(child) + planLater(initialState), hasInitialState, skipEmittingInitialStateKeys, planLater(child) ) execPlan :: Nil case _ => @@ -828,7 +830,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val execPlan = python.FlatMapGroupsInPandasWithStateExec( func, groupAttr, outputAttr, stateType, None, stateVersion, outputMode, timeout, batchTimestampMs = None, eventTimeWatermarkForLateEvents = None, - eventTimeWatermarkForEviction = None, planLater(child) + eventTimeWatermarkForEviction = None, + skipEmittingInitialStateKeys = false, + planLater(child) ) execPlan :: Nil case _ => @@ -953,10 +957,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { f, keyDeserializer, valueDeserializer, grouping, data, output, stateEncoder, outputMode, isFlatMapGroupsWithState, timeout, hasInitialState, initialStateGroupAttrs, initialStateDataAttrs, initialStateDeserializer, initialState, child) => + val skipEmittingInitialStateKeys = + conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS) FlatMapGroupsWithStateExec.generateSparkPlanForBatchQueries( f, keyDeserializer, valueDeserializer, initialStateDeserializer, grouping, initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout, - hasInitialState, planLater(initialState), planLater(child) + hasInitialState, skipEmittingInitialStateKeys, planLater(initialState), planLater(child) ) :: Nil case logical.TransformWithState(keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, statefulProcessor, timeMode, outputMode, keyEncoder, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index eef0b3e3e8469..76bb164436624 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -50,6 +50,7 @@ import org.apache.spark.util.CompletionIterator * @param batchTimestampMs processing timestamp of the current batch. * @param eventTimeWatermarkForLateEvents event time watermark for filtering late events * @param eventTimeWatermarkForEviction event time watermark for state eviction + * @param skipEmittingInitialStateKeys whether to skip emitting initial state df keys * @param child logical plan of the underlying data */ case class FlatMapGroupsInPandasWithStateExec( @@ -64,6 +65,7 @@ case class FlatMapGroupsInPandasWithStateExec( batchTimestampMs: Option[Long], eventTimeWatermarkForLateEvents: Option[Long], eventTimeWatermarkForEviction: Option[Long], + skipEmittingInitialStateKeys: Boolean, child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase { // TODO(SPARK-40444): Add the support of initial state. @@ -137,7 +139,8 @@ case class FlatMapGroupsInPandasWithStateExec( override def processNewDataWithInitialState( childDataIter: Iterator[InternalRow], - initStateIter: Iterator[InternalRow]): Iterator[InternalRow] = { + initStateIter: Iterator[InternalRow], + skipEmittingInitialStateKeys: Boolean): Iterator[InternalRow] = { throw SparkUnsupportedOperationException() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 58d2a19989cbf..5fe3b0f82a0a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -52,6 +52,7 @@ trait FlatMapGroupsWithStateExecBase protected val initialStateDataAttrs: Seq[Attribute] protected val initialState: SparkPlan protected val hasInitialState: Boolean + protected val skipEmittingInitialStateKeys: Boolean val stateInfo: Option[StatefulOperatorStateInfo] protected val stateEncoder: ExpressionEncoder[Any] @@ -145,7 +146,8 @@ trait FlatMapGroupsWithStateExecBase val processedOutputIterator = initialStateIterOption match { case Some(initStateIter) if initStateIter.hasNext => - processor.processNewDataWithInitialState(filteredIter, initStateIter) + processor.processNewDataWithInitialState(filteredIter, initStateIter, + skipEmittingInitialStateKeys) case _ => processor.processNewData(filteredIter) } @@ -301,7 +303,8 @@ trait FlatMapGroupsWithStateExecBase */ def processNewDataWithInitialState( childDataIter: Iterator[InternalRow], - initStateIter: Iterator[InternalRow] + initStateIter: Iterator[InternalRow], + skipEmittingInitialStateKeys: Boolean ): Iterator[InternalRow] = { if (!childDataIter.hasNext && !initStateIter.hasNext) return Iterator.empty @@ -312,7 +315,8 @@ trait FlatMapGroupsWithStateExecBase val groupedInitialStateIter = GroupedIterator(initStateIter, initialStateGroupAttrs, initialState.output) - // Create a CoGroupedIterator that will group the two iterators together for every key group. + // Create a CoGroupedIterator that will group the two iterators together for every + // key group. new CoGroupedIterator( groupedChildDataIter, groupedInitialStateIter, groupingAttributes).flatMap { case (keyRow, valueRowIter, initialStateRowIter) => @@ -326,12 +330,17 @@ trait FlatMapGroupsWithStateExecBase val initStateObj = getStateObj.get(initialStateRow) stateManager.putState(store, keyUnsafeRow, initStateObj, NO_TIMESTAMP) } - // We apply the values for the key after applying the initial state. - callFunctionAndUpdateState( - stateManager.getState(store, keyUnsafeRow), + + if (skipEmittingInitialStateKeys && valueRowIter.isEmpty) { + // If the user has specified to skip emitting the keys that only have initial state + // and no data, then we should not call the function for such keys. + Iterator.empty + } else { + callFunctionAndUpdateState( + stateManager.getState(store, keyUnsafeRow), valueRowIter, - hasTimedOut = false - ) + hasTimedOut = false) + } } } @@ -388,6 +397,7 @@ trait FlatMapGroupsWithStateExecBase * @param eventTimeWatermarkForEviction event time watermark for state eviction * @param initialState the user specified initial state * @param hasInitialState indicates whether the initial state is provided or not + * @param skipEmittingInitialStateKeys whether to skip emitting initial state df keys * @param child the physical plan for the underlying data */ case class FlatMapGroupsWithStateExec( @@ -410,6 +420,7 @@ case class FlatMapGroupsWithStateExec( eventTimeWatermarkForEviction: Option[Long], initialState: SparkPlan, hasInitialState: Boolean, + skipEmittingInitialStateKeys: Boolean, child: SparkPlan) extends FlatMapGroupsWithStateExecBase with BinaryExecNode with ObjectProducerExec { import GroupStateImpl._ @@ -533,6 +544,7 @@ object FlatMapGroupsWithStateExec { outputObjAttr: Attribute, timeoutConf: GroupStateTimeout, hasInitialState: Boolean, + skipEmittingInitialStateKeys: Boolean, initialState: SparkPlan, child: SparkPlan): SparkPlan = { if (hasInitialState) { @@ -541,27 +553,31 @@ object FlatMapGroupsWithStateExec { case _ => false } val func = (keyRow: Any, values: Iterator[Any], states: Iterator[Any]) => { - // Check if there is only one state for every key. - var foundInitialStateForKey = false - val optionalStates = states.map { stateValue => - if (foundInitialStateForKey) { - foundDuplicateInitialKeyException() - } - foundInitialStateForKey = true - stateValue - }.toArray - - // Create group state object - val groupState = GroupStateImpl.createForStreaming( - optionalStates.headOption, - System.currentTimeMillis, - GroupStateImpl.NO_TIMESTAMP, - timeoutConf, - hasTimedOut = false, - watermarkPresent) - - // Call user function with the state and values for this key - userFunc(keyRow, values, groupState) + if (skipEmittingInitialStateKeys && values.isEmpty) { + Iterator.empty + } else { + // Check if there is only one state for every key. + var foundInitialStateForKey = false + val optionalStates = states.map { stateValue => + if (foundInitialStateForKey) { + foundDuplicateInitialKeyException() + } + foundInitialStateForKey = true + stateValue + }.toArray + + // Create group state object + val groupState = GroupStateImpl.createForStreaming( + optionalStates.headOption, + System.currentTimeMillis, + GroupStateImpl.NO_TIMESTAMP, + timeoutConf, + hasTimedOut = false, + watermarkPresent) + + // Call user function with the state and values for this key + userFunc(keyRow, values, groupState) + } } CoGroupExec( func, keyDeserializer, valueDeserializer, initialStateDeserializer, groupingAttributes, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index f7ff39622ed40..f1feb62b7622a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -1177,6 +1177,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { Some(currentBatchTimestamp), Some(0), Some(currentBatchWatermark), RDDScanExec(g, emptyRdd, "rdd"), hasInitialState, + false, RDDScanExec(g, emptyRdd, "rdd")) }.get } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala index 2a2a83d35e1f8..dd4e3615d43ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala @@ -351,6 +351,135 @@ class FlatMapGroupsWithStateWithInitialStateSuite extends StateStoreMetricsTest ) } + // if the keys part of initial state df are different than the keys in the input data, then + // they will not be emitted as part of the result with skipEmittingInitialStateKeys set to true + testWithAllStateVersions("flatMapGroupsWithState - initial state - " + + s"skipEmittingInitialStateKeys=true") { + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS.key -> "true") { + val initialState = Seq( + ("apple", 1L), + ("orange", 2L), + ("mango", 5L)).toDS().groupByKey(_._1).mapValues(_._2) + + val fruitCountFunc = (key: String, values: Iterator[String], state: GroupState[Long]) => { + val count = state.getOption.map( x => x).getOrElse(0L) + values.size + state.update(count) + Iterator.single((key, count)) + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, NoTimeout(), initialState)(fruitCountFunc) + testStream(result, Update)( + AddData(inputData, "apple"), + AddData(inputData, "banana"), + CheckNewAnswer(("apple", 2), ("banana", 1)), + AddData(inputData, "orange"), + CheckNewAnswer(("orange", 3)), + StopStream + ) + } + } + + // if the keys part of initial state df are different than the keys in the input data, then + // they will be emitted as part of the result with skipEmittedInitialStateKeys set to false + testWithAllStateVersions("flatMapGroupsWithState - initial state - " + + s"skipEmittingInitialStateKeys=false") { + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS.key -> "false") { + val initialState = Seq( + ("apple", 1L), + ("orange", 2L), + ("mango", 5L)).toDS().groupByKey(_._1).mapValues(_._2) + + val fruitCountFunc = (key: String, values: Iterator[String], state: GroupState[Long]) => { + val count = state.getOption.map( x => x).getOrElse(0L) + values.size + state.update(count) + Iterator.single((key, count)) + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, NoTimeout(), initialState)(fruitCountFunc) + testStream(result, Update)( + AddData(inputData, "apple"), + AddData(inputData, "banana"), + CheckNewAnswer(("apple", 2), ("banana", 1), ("orange", 2), ("mango", 5)), + AddData(inputData, "orange"), + CheckNewAnswer(("orange", 3)), + StopStream + ) + } + } + + // if the keys part of the initial state and the first batch are the same, then the result + // is the same irrespective of the value of skipEmittingInitialStateKeys + Seq(true, false).foreach { skipEmittingInitialStateKeys => + testWithAllStateVersions("flatMapGroupsWithState - initial state and initial batch " + + s"have same keys and skipEmittingInitialStateKeys=$skipEmittingInitialStateKeys") { + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS.key -> + skipEmittingInitialStateKeys.toString) { + val initialState = Seq( + ("apple", 1L), + ("orange", 2L)).toDS().groupByKey(_._1).mapValues(_._2) + + val fruitCountFunc = (key: String, values: Iterator[String], state: GroupState[Long]) => { + val count = state.getOption.map(x => x).getOrElse(0L) + values.size + state.update(count) + Iterator.single((key, count)) + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, NoTimeout(), initialState)(fruitCountFunc) + testStream(result, Update)( + AddData(inputData, "apple"), + AddData(inputData, "apple"), + AddData(inputData, "orange"), + CheckNewAnswer(("apple", 3), ("orange", 3)), + AddData(inputData, "orange"), + CheckNewAnswer(("orange", 4)), + StopStream + ) + } + } + } + + Seq(true, false).foreach { skipEmittingInitialStateKeys => + testWithAllStateVersions("flatMapGroupsWithState - batch query and " + + s"skipEmittingInitialStateKeys=$skipEmittingInitialStateKeys") { + withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS.key -> + skipEmittingInitialStateKeys.toString) { + val initialState = Seq( + ("apple", 1L), + ("orange", 2L)).toDS().groupByKey(_._1).mapValues(_._2) + + val fruitCountFunc = (key: String, values: Iterator[String], state: GroupState[Long]) => { + val count = state.getOption.map(x => x).getOrElse(0L) + values.size + state.update(count) + Iterator.single((key, count)) + } + + val inputData = Seq("orange", "mango") + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, NoTimeout(), initialState)(fruitCountFunc) + val df = result.toDF() + if (skipEmittingInitialStateKeys) { + checkAnswer(df, Seq(("orange", 3), ("mango", 1)).toDF()) + } else { + checkAnswer(df, Seq(("apple", 1), ("orange", 3), ("mango", 1)).toDF()) + } + } + } + } + def testWithAllStateVersions(name: String)(func: => Unit): Unit = { for (version <- FlatMapGroupsWithStateExecHelper.supportedVersions) { test(s"$name - state format version $version") {