diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index a76e63df873e7..eb6bec4505058 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -22,6 +22,8 @@ import java.time.Duration
 import java.util.UUID
 
 import org.apache.hadoop.fs.{FileStatus, Path}
+import org.scalactic.source.Position
+import org.scalatest.Tag
 import org.scalatest.matchers.must.Matchers.be
 import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
 import org.scalatest.time.{Seconds, Span}
@@ -736,8 +738,8 @@ class SleepingTimerProcessor extends StatefulProcessor[String, String, String] {
  * Class that adds tests for transformWithState stateful streaming operator
  */
 @SlowSQLTest
-class TransformWithStateSuite extends StateStoreMetricsTest
-  with AlsoTestWithEncodingTypes with AlsoTestWithRocksDBFeatures {
+abstract class TransformWithStateTest extends StateStoreMetricsTest
+  with AlsoTestWithRocksDBFeatures {
 
   import testImplicits._
 
@@ -924,1442 +926,1501 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     }
   }
 
-  testWithEncoding("avro")("transformWithState - value schema threshold exceeded") {
-    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
-      classOf[RocksDBStateStoreProvider].getName,
-      SQLConf.SHUFFLE_PARTITIONS.key ->
-        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString,
-      SQLConf.STREAMING_VALUE_STATE_SCHEMA_EVOLUTION_THRESHOLD.key -> "0") {
-      withTempDir { chkptDir =>
-        val dirPath = chkptDir.getCanonicalPath
-        val inputData = MemoryStream[String]
-        val result1 = inputData.toDS()
-          .groupByKey(x => x)
-          .transformWithState(new RunningCountStatefulProcessorInt(),
-            TimeMode.None(),
-            OutputMode.Update())
-
-        testStream(result1, OutputMode.Update())(
-          StartStream(checkpointLocation = dirPath),
-          AddData(inputData, "a"),
-          CheckNewAnswer(("a", "1")),
-          Execute { q =>
-            assert(q.lastProgress.stateOperators(0).customMetrics.get("numValueStateVars") > 0)
-            assert(q.lastProgress.stateOperators(0).customMetrics.get("numRegisteredTimers") == 0)
-            assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1)
-          },
-          AddData(inputData, "a", "b"),
-          CheckNewAnswer(("a", "2"), ("b", "1")),
-          StopStream,
-          StartStream(checkpointLocation = dirPath),
-          AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a
-          CheckNewAnswer(("b", "2")),
-          StopStream,
-          Execute { q =>
-            assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1)
-            assert(q.lastProgress.stateOperators(0).numRowsRemoved === 1)
-          },
-          StartStream(checkpointLocation = dirPath),
-          AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and
-          CheckNewAnswer(("a", "1"), ("c", "1"))
-        )
-
-        val result2 = inputData.toDS()
-          .groupByKey(x => x)
-          .transformWithState(new RunningCountStatefulProcessor(),
-            TimeMode.None(),
-            OutputMode.Update())
-
-        testStream(result2, OutputMode.Update())(
-          StartStream(checkpointLocation = dirPath),
-          AddData(inputData, "a"),
-          ExpectFailure[StateStoreValueSchemaEvolutionThresholdExceeded] { t =>
-            checkError(
-              t.asInstanceOf[StateStoreValueSchemaEvolutionThresholdExceeded],
-              condition = "STATE_STORE_VALUE_SCHEMA_EVOLUTION_THRESHOLD_EXCEEDED",
-              parameters = Map(
-                "numSchemaEvolutions" -> "1",
-                "maxSchemaEvolutions" -> "0",
-                "colFamilyName" -> "countState"
-              )
-            )
-          }
- ) - } - } - } - - testWithEncoding("avro")("transformWithState - upcasting should succeed") { + test("transformWithState - streaming with rocksdb and processing time timer " + + "and updating timers should succeed") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - withTempDir { chkptDir => - val dirPath = chkptDir.getCanonicalPath - val inputData = MemoryStream[String] - val result1 = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessorInt(), - TimeMode.None(), - OutputMode.Update()) - - testStream(result1, OutputMode.Update())( - StartStream(checkpointLocation = dirPath), - AddData(inputData, "a"), - CheckNewAnswer(("a", "1")), - Execute { q => - assert(q.lastProgress.stateOperators(0).customMetrics.get("numValueStateVars") > 0) - assert(q.lastProgress.stateOperators(0).customMetrics.get("numRegisteredTimers") == 0) - assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1) - }, - AddData(inputData, "a", "b"), - CheckNewAnswer(("a", "2"), ("b", "1")), - StopStream, - StartStream(checkpointLocation = dirPath), - AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a - CheckNewAnswer(("b", "2")), - StopStream, - Execute { q => - assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1) - assert(q.lastProgress.stateOperators(0).numRowsRemoved === 1) - }, - StartStream(checkpointLocation = dirPath), - AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and - CheckNewAnswer(("a", "1"), ("c", "1")) - ) - - val result2 = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessor(), - TimeMode.None(), - OutputMode.Update()) - - testStream(result2, OutputMode.Update())( - StartStream(checkpointLocation = dirPath), - AddData(inputData, "a"), - CheckNewAnswer(("a", "2")), - AddData(inputData, "d"), - CheckNewAnswer(("d", "1")), - StopStream - ) - } - } - } - - testWithEncoding("avro")("transformWithState - reordering fields should succeed") { - withSQLConf( - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - withTempDir { chkptDir => - val dirPath = chkptDir.getCanonicalPath - val inputData = MemoryStream[String] + classOf[RocksDBStateStoreProvider].getName) { + val clock = new StreamManualClock - // First run with initial field order - val result1 = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessorInitialOrder(), - TimeMode.None(), - OutputMode.Update()) + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState( + new RunningCountStatefulProcessorWithProcTimeTimerUpdates(), + TimeMode.ProcessingTime(), + OutputMode.Update()) - testStream(result1, OutputMode.Update())( - StartStream(checkpointLocation = dirPath), - AddData(inputData, "a"), - CheckNewAnswer(("a", "1")), - StopStream - ) + testStream(result, OutputMode.Update())( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "1")), // at batch 0, ts = 1, timer = "a" -> [6] (= 1 + 5) - // Second run with reordered fields - val result2 = inputData.toDS() - .groupByKey(x => x) - 
.transformWithState(new RunningCountStatefulProcessorReorderedFields(), - TimeMode.None(), - OutputMode.Update()) + AddData(inputData, "a"), + AdvanceManualClock(2 * 1000), + CheckNewAnswer(("a", "2")), // at batch 1, ts = 3, timer = "a" -> [9.5] (2 + 7.5) + StopStream, - testStream(result2, OutputMode.Update())( - StartStream(checkpointLocation = dirPath), - AddData(inputData, "a"), - CheckNewAnswer(("a", "2")), // Should continue counting from previous state - StopStream - ) - } + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, "d"), + AdvanceManualClock(10 * 1000), + CheckNewAnswer(("a", "-1"), ("d", "1")), // at batch 2, ts = 13, timer for "a" is expired. + // If the timer of "a" was not replaced (pure addition), it would have triggered the timer + // two times here and produced ("a", "-1") two times. + StopStream + ) } } - testWithEncoding("avro")("transformWithState - adding field should succeed") { + test("transformWithState - streaming with rocksdb and processing time timer " + + "and multiple timers should succeed") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - withTempDir { chkptDir => - val dirPath = chkptDir.getCanonicalPath - val inputData = MemoryStream[String] - val result1 = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessor(), - TimeMode.None(), - OutputMode.Update()) - - testStream(result1, OutputMode.Update())( - StartStream(checkpointLocation = dirPath), - AddData(inputData, "a"), - CheckNewAnswer(("a", "1")), - Execute { q => - assert(q.lastProgress.stateOperators(0).customMetrics.get("numValueStateVars") > 0) - assert(q.lastProgress.stateOperators(0).customMetrics.get("numRegisteredTimers") == 0) - assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1) - }, - AddData(inputData, "a", "b"), - CheckNewAnswer(("a", "2"), ("b", "1")), - StopStream, - StartStream(checkpointLocation = dirPath), - AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a - CheckNewAnswer(("b", "2")), - StopStream, - Execute { q => - assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1) - assert(q.lastProgress.stateOperators(0).numRowsRemoved === 1) - }, - StartStream(checkpointLocation = dirPath), - AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and - CheckNewAnswer(("a", "1"), ("c", "1")) - ) - - val result2 = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessorNestedLongs(), - TimeMode.None(), - OutputMode.Update()) - - testStream(result2, OutputMode.Update())( - StartStream(checkpointLocation = dirPath), - AddData(inputData, "a"), - CheckNewAnswer(("a", "2")), - StopStream - ) - } - } - } + classOf[RocksDBStateStoreProvider].getName) { + val clock = new StreamManualClock - testWithEncoding("avro")("transformWithState - add and remove field between runs") { - withSQLConf( - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - withTempDir { dir => - val inputData = MemoryStream[String] + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState( + new RunningCountStatefulProcessorWithMultipleTimers(), + TimeMode.ProcessingTime(), + OutputMode.Update()) - // First run with original field names 
- val result1 = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessorInitialOrder(), - TimeMode.None(), - OutputMode.Update()) + testStream(result, OutputMode.Update())( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), // at batch 0, add 3 timers for given key = "a" - testStream(result1, OutputMode.Update())( - StartStream(checkpointLocation = dir.getCanonicalPath), - AddData(inputData, "test1"), - CheckNewAnswer(("test1", "1")), - StopStream - ) + AddData(inputData, "a"), + AdvanceManualClock(6 * 1000), + CheckNewAnswer(("a", "2")), // at ts = 7, first timer expires and produces ("a", "2") - // Second run with renamed field - val result2 = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RenameEvolvedProcessor(), - TimeMode.None(), - OutputMode.Update()) + AddData(inputData, "a"), + AdvanceManualClock(5 * 1000), + CheckNewAnswer(("a", "3")), // at ts = 12, second timer expires and produces ("a", "3") + StopStream, - testStream(result2, OutputMode.Update())( - StartStream(checkpointLocation = dir.getCanonicalPath), - // Uses default value, does not factor previous value1 into this - AddData(inputData, "test1"), - CheckNewAnswer(("test1", "1")), - // Verify we can write state with new field name - AddData(inputData, "test2"), - CheckNewAnswer(("test2", "1")), - StopStream - ) - } + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, "a"), + AdvanceManualClock(5 * 1000), + CheckNewAnswer(("a", "4")), // at ts = 17, third timer expires and produces ("a", "4") + StopStream + ) } } - testWithEncoding("avro")("state data source - schema evolution with time travel support") { - withSQLConf( - rocksdbChangelogCheckpointingConfKey -> "true", - SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> "1", - SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1") { - - withTempDir { chkptDir => - val dirPath = chkptDir.getCanonicalPath - val inputData = MemoryStream[String] - - val result1 = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessorTwoLongs(), - TimeMode.None(), - OutputMode.Update()) - - testStream(result1, OutputMode.Update())( - StartStream(checkpointLocation = dirPath), - AddData(inputData, "a"), - CheckNewAnswer(("a", "1")), - AddData(inputData, "b"), - CheckNewAnswer(("b", "1")), - ProcessAllAvailable(), - Execute { _ => Thread.sleep(5000) }, - StopStream - ) - - val result2 = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RenameEvolvedProcessor(), - TimeMode.None(), - OutputMode.Update()) - - testStream(result2, OutputMode.Update())( - StartStream(checkpointLocation = dirPath), - AddData(inputData, "c"), - CheckNewAnswer(("c", "1")), - AddData(inputData, "d"), - CheckNewAnswer(("d", "1")), - ProcessAllAvailable(), - Execute { _ => Thread.sleep(5000) }, - StopStream - ) - - val oldStateDf = spark.read - .format("statestore") - .option("snapshotStartBatchId", 0) - .option("batchId", 1) - .option("snapshotPartitionId", 0) - .option(StateSourceOptions.STATE_VAR_NAME, "countState") - .load(dirPath) + test("transformWithState - streaming with rocksdb and event " + + "time based timer") { + val inputData = MemoryStream[(String, Int)] + val result = + inputData.toDS() + .select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime")) + 
.withWatermark("eventTime", "10 seconds") + .as[(String, Long)] + .groupByKey(_._1) + .transformWithState( + new MaxEventTimeStatefulProcessor(), + TimeMode.EventTime(), + OutputMode.Update()) - checkAnswer( - oldStateDf.selectExpr( - "key.value AS groupingKey", - "value.value1 AS count"), - Seq(Row("a", 1), Row("b", 1)) - ) + testStream(result, OutputMode.Update())( + StartStream(), - val evolvedStateDf1 = spark.read - .format("statestore") - .option("snapshotStartBatchId", 0) - .option("batchId", 3) - .option("snapshotPartitionId", 0) - .option(StateSourceOptions.STATE_VAR_NAME, "countState") - .load(dirPath) + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), + // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5. + CheckNewAnswer(("a", 15)), // Output = max event time of a - checkAnswer( - evolvedStateDf1.selectExpr( - "key.value AS groupingKey", - "value.value4 AS count"), - Seq( - Row("a", null), - Row("b", null), - Row("c", 1), - Row("d", 1) - ) - ) + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" + CheckNewAnswer(), // No output as data should get filtered by watermark - val evolvedStateDf = spark.read - .format("statestore") - .option("snapshotStartBatchId", 3) - .option("snapshotPartitionId", 0) - .option(StateSourceOptions.STATE_VAR_NAME, "countState") - .load(dirPath) + AddData(inputData, ("a", 10)), // Add data newer than watermark for "a" + CheckNewAnswer(("a", 15)), // Max event time is still the same + // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15. + // Watermark is still 5 as max event time for all data is still 15. - checkAnswer( - evolvedStateDf.selectExpr( - "key.value AS groupingKey", - "value.value4 AS count"), - Seq( - Row("a", null), - Row("b", null), - Row("c", 1), - Row("d", 1) - ) - ) + AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a" + // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is 20. 
+ CheckNewAnswer(("a", -1), ("b", 31)), // State for "a" should timeout and emit -1 + Execute { q => + // Filter for idle progress events and then verify the custom metrics for stateful operator + val progData = q.recentProgress.filter(prog => prog.stateOperators.size > 0) + assert(progData.filter(prog => + prog.stateOperators(0).customMetrics.get("numValueStateVars") > 0).size > 0) + assert(progData.filter(prog => + prog.stateOperators(0).customMetrics.get("numRegisteredTimers") > 0).size > 0) + assert(progData.filter(prog => + prog.stateOperators(0).customMetrics.get("numDeletedTimers") > 0).size > 0) } - } + ) } - testWithEncoding("avro")("transformWithState - verify default values during schema evolution") { - withSQLConf( - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - withTempDir { dir => - val inputData = MemoryStream[String] - - // First run with basic schema - val result1 = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new DefaultValueInitialProcessor(), - TimeMode.None(), - OutputMode.Update()) + test("transformWithState - timer duration should be reflected in metrics") { + val clock = new StreamManualClock + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState( + new SleepingTimerProcessor, TimeMode.ProcessingTime(), OutputMode.Update()) - testStream(result1, OutputMode.Update())( - StartStream(checkpointLocation = dir.getCanonicalPath), - AddData(inputData, "test1"), - CheckNewAnswer(("test1", BasicState("test1".hashCode, "test1"))), - StopStream - ) + testStream(result, OutputMode.Update())( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + // Side effect: timer scheduled for t = 1 + 10 = 11. + CheckNewAnswer(), + Execute { q => + val metrics = q.lastProgress.stateOperators(0).customMetrics + assert(metrics.get("numRegisteredTimers") === 1) + assert(metrics.get("timerProcessingTimeMs") < 2000) + }, - // Second run with evolved schema to check defaults - val result2 = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new DefaultValueEvolvedProcessor(), - TimeMode.None(), - OutputMode.Update()) + AddData(inputData, "b"), + AdvanceManualClock(1 * 1000), + // Side effect: timer scheduled for t = 2 + 10 = 12. + CheckNewAnswer(), + Execute { q => + val metrics = q.lastProgress.stateOperators(0).customMetrics + assert(metrics.get("numRegisteredTimers") === 1) + assert(metrics.get("timerProcessingTimeMs") < 2000) + }, - testStream(result2, OutputMode.Update())( - StartStream(checkpointLocation = dir.getCanonicalPath), + AddData(inputData, "c"), + // Time is currently 2 and we need to advance past 12. So, advance by 11 seconds. + AdvanceManualClock(11 * 1000), + CheckNewAnswer("a", "b"), + Execute { q => + val metrics = q.lastProgress.stateOperators(0).customMetrics + assert(metrics.get("numRegisteredTimers") === 1) - // Check existing state - new fields should get default values - AddData(inputData, "test1"), - CheckNewAnswer( - ("test1", EvolvedState( - id = "test1".hashCode, - name = "test1", - count = 0L, - active = false, - score = 0.0 - )) - ), + // Both timers should have fired and taken 1 second each to process. 
+ assert(metrics.get("timerProcessingTimeMs") >= 2000) + }, - // New state should get initialized values, not defaults - AddData(inputData, "test2"), - CheckNewAnswer( - ("test2", EvolvedState( - id = "test2".hashCode, - name = "test2", - count = 100L, - active = true, - score = 99.9 - )) - ), - StopStream - ) - } + StopStream + ) + } + + test("Use statefulProcessor without transformWithState -" + + " handle should be absent") { + val processor = new RunningCountStatefulProcessor() + val ex = intercept[Exception] { + processor.getHandle } + checkError( + ex.asInstanceOf[SparkRuntimeException], + condition = "STATE_STORE_HANDLE_NOT_INITIALIZED", + parameters = Map.empty + ) + } + + test("transformWithState - batch should succeed") { + val inputData = Seq("a", "b") + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Append()) + + val df = result.toDF() + checkAnswer(df, Seq(("a", "1"), ("b", "1")).toDF()) } - testWithEncoding("avro")("transformWithState - removing field should succeed") { + test("transformWithState - test deleteIfExists operator") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { withTempDir { chkptDir => val dirPath = chkptDir.getCanonicalPath - val inputData = MemoryStream[String] + val inputData = MemoryStream[(String, String)] + val stream1 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) - val result2 = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessorTwoLongs(), + val stream2 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new MostRecentStatefulProcessorWithDeletion(), TimeMode.None(), OutputMode.Update()) - testStream(result2, OutputMode.Update())( + testStream(stream1, OutputMode.Update())( StartStream(checkpointLocation = dirPath), - AddData(inputData, "a"), - CheckNewAnswer(("a", "1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "1", "")), StopStream ) - - val result1 = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessor(), - TimeMode.None(), - OutputMode.Update()) - - testStream(result1, OutputMode.Update())( + testStream(stream2, OutputMode.Update())( StartStream(checkpointLocation = dirPath), - AddData(inputData, "a"), - CheckNewAnswer(("a", "1")), + AddData(inputData, ("a", "str2"), ("b", "str3")), + CheckNewAnswer(("a", "str1"), + ("b", "")), // should not factor in previous count state + Execute { q => + assert(q.lastProgress.stateOperators(0).customMetrics.get("numValueStateVars") > 0) + assert(q.lastProgress.stateOperators(0).customMetrics.get("numDeletedStateVars") > 0) + }, StopStream ) } } } - test("transformWithState - streaming with rocksdb and processing time timer " + - "and updating timers should succeed") { + test("transformWithState - two input streams") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName) { - val clock = new StreamManualClock + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + val inputData1 = MemoryStream[String] + val inputData2 = MemoryStream[String] - val inputData = MemoryStream[String] - val result = inputData.toDS() + 
val result = inputData1.toDS() + .union(inputData2.toDS()) .groupByKey(x => x) - .transformWithState( - new RunningCountStatefulProcessorWithProcTimeTimerUpdates(), - TimeMode.ProcessingTime(), + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), OutputMode.Update()) testStream(result, OutputMode.Update())( - StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), - AddData(inputData, "a"), - AdvanceManualClock(1 * 1000), - CheckNewAnswer(("a", "1")), // at batch 0, ts = 1, timer = "a" -> [6] (= 1 + 5) + AddData(inputData1, "a"), + CheckNewAnswer(("a", "1")), + AddData(inputData2, "a", "b"), + CheckNewAnswer(("a", "2"), ("b", "1")), + AddData(inputData1, "a", "b"), // should remove state for "a" and not return anything for a + CheckNewAnswer(("b", "2")), + AddData(inputData1, "d", "e"), + AddData(inputData2, "a", "c"), // should recreate state for "a" and return count as 1 + CheckNewAnswer(("a", "1"), ("c", "1"), ("d", "1"), ("e", "1")), + StopStream + ) + } + } - AddData(inputData, "a"), - AdvanceManualClock(2 * 1000), - CheckNewAnswer(("a", "2")), // at batch 1, ts = 3, timer = "a" -> [9.5] (2 + 7.5) - StopStream, + test("transformWithState - three input streams") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + val inputData1 = MemoryStream[String] + val inputData2 = MemoryStream[String] + val inputData3 = MemoryStream[String] - StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), - AddData(inputData, "d"), - AdvanceManualClock(10 * 1000), - CheckNewAnswer(("a", "-1"), ("d", "1")), // at batch 2, ts = 13, timer for "a" is expired. - // If the timer of "a" was not replaced (pure addition), it would have triggered the timer - // two times here and produced ("a", "-1") two times. 
+ // union 3 input streams + val result = inputData1.toDS() + .union(inputData2.toDS()) + .union(inputData3.toDS()) + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + AddData(inputData1, "a"), + CheckNewAnswer(("a", "1")), + AddData(inputData2, "a", "b"), + CheckNewAnswer(("a", "2"), ("b", "1")), + AddData(inputData3, "a", "b"), // should remove state for "a" and not return anything for a + CheckNewAnswer(("b", "2")), + AddData(inputData1, "d", "e"), + AddData(inputData2, "a", "c"), // should recreate state for "a" and return count as 1 + CheckNewAnswer(("a", "1"), ("c", "1"), ("d", "1"), ("e", "1")), + AddData(inputData3, "a", "c", "d", "e"), + CheckNewAnswer(("a", "2"), ("c", "2"), ("d", "2"), ("e", "2")), + StopStream + ) + } + } + + test("transformWithState - two input streams, different key type") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + val inputData1 = MemoryStream[String] + val inputData2 = MemoryStream[Long] + + val result = inputData1.toDS() + // union inputData2 by casting it to a String + .union(inputData2.toDS().map(_.toString)) + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + AddData(inputData1, "1"), + CheckNewAnswer(("1", "1")), + AddData(inputData2, 1L, 2L), + CheckNewAnswer(("1", "2"), ("2", "1")), + AddData(inputData1, "1", "2"), // should remove state for "1" and not return anything. + CheckNewAnswer(("2", "2")), + AddData(inputData1, "4", "5"), + AddData(inputData2, 1L, 3L), // should recreate state for "1" and return count as 1 + CheckNewAnswer(("1", "1"), ("3", "1"), ("4", "1"), ("5", "1")), StopStream ) } } - test("transformWithState - streaming with rocksdb and processing time timer " + - "and multiple timers should succeed") { + /** Create a text file with a single data item */ + private def createFile(data: String, srcDir: File): File = + stringToFile(new File(srcDir, s"${UUID.randomUUID()}.txt"), data) + + private def createFileStream(srcDir: File): Dataset[(String, String)] = { + spark + .readStream + .option("maxFilesPerTrigger", "1") + .text(srcDir.getCanonicalPath) + .select("value").as[String] + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + } + + test("transformWithState - availableNow trigger mode, rate limit is respected") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName) { + withTempDir { srcDir => + + Seq("a", "b", "c").foreach(createFile(_, srcDir)) + + // Set up a query to read text files one at a time + val df = createFileStream(srcDir) + + testStream(df)( + StartStream(trigger = Trigger.AvailableNow()), + ProcessAllAvailable(), + CheckNewAnswer(("a", "1"), ("b", "1"), ("c", "1")), + StopStream, + Execute { _ => + createFile("a", srcDir) + }, + StartStream(trigger = Trigger.AvailableNow()), + ProcessAllAvailable(), + CheckNewAnswer(("a", "2")) + ) + + var index = 0 + val foreachBatchDf = df.writeStream + .foreachBatch((_: Dataset[(String, String)], _: Long) => { + index += 1 + }) + .trigger(Trigger.AvailableNow()) + .start() + + try { + foreachBatchDf.awaitTermination() + assert(index == 4) + } finally { + 
foreachBatchDf.stop() + } + } + } + } + + test("transformWithState - availableNow trigger mode, multiple restarts") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { - val clock = new StreamManualClock - - val inputData = MemoryStream[String] - val result = inputData.toDS() - .groupByKey(x => x) - .transformWithState( - new RunningCountStatefulProcessorWithMultipleTimers(), - TimeMode.ProcessingTime(), - OutputMode.Update()) + withTempDir { srcDir => + Seq("a", "b", "c").foreach(createFile(_, srcDir)) + val df = createFileStream(srcDir) - testStream(result, OutputMode.Update())( - StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), - AddData(inputData, "a"), - AdvanceManualClock(1 * 1000), // at batch 0, add 3 timers for given key = "a" + var index = 0 - AddData(inputData, "a"), - AdvanceManualClock(6 * 1000), - CheckNewAnswer(("a", "2")), // at ts = 7, first timer expires and produces ("a", "2") + def startTriggerAvailableNowQueryAndCheck(expectedIdx: Int): Unit = { + val q = df.writeStream + .foreachBatch((_: Dataset[(String, String)], _: Long) => { + index += 1 + }) + .trigger(Trigger.AvailableNow) + .start() + try { + assert(q.awaitTermination(streamingTimeout.toMillis)) + assert(index == expectedIdx) + } finally { + q.stop() + } + } + // start query for the first time + startTriggerAvailableNowQueryAndCheck(3) - AddData(inputData, "a"), - AdvanceManualClock(5 * 1000), - CheckNewAnswer(("a", "3")), // at ts = 12, second timer expires and produces ("a", "3") - StopStream, + // add two files and restart + createFile("a", srcDir) + createFile("b", srcDir) + startTriggerAvailableNowQueryAndCheck(8) - StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), - AddData(inputData, "a"), - AdvanceManualClock(5 * 1000), - CheckNewAnswer(("a", "4")), // at ts = 17, third timer expires and produces ("a", "4") - StopStream - ) + // try restart again + createFile("d", srcDir) + startTriggerAvailableNowQueryAndCheck(14) + } } } - test("transformWithState - streaming with rocksdb and event " + - "time based timer") { - val inputData = MemoryStream[(String, Int)] - val result = - inputData.toDS() - .select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime")) - .withWatermark("eventTime", "10 seconds") - .as[(String, Long)] - .groupByKey(_._1) - .transformWithState( - new MaxEventTimeStatefulProcessor(), - TimeMode.EventTime(), - OutputMode.Update()) + Seq(false, true).foreach { useImplicits => + test("transformWithState - verify StateSchemaV3 writes " + + s"correct SQL schema of key/value with useImplicits=$useImplicits") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { checkpointDir => + // When Avro is used, we want to set the StructFields to nullable + val shouldBeNullable = usingAvroEncoding() + val metadataPathPostfix = "state/0/_stateSchema/default" + val stateSchemaPath = new Path(checkpointDir.toString, + s"$metadataPathPostfix") + val hadoopConf = spark.sessionState.newHadoopConf() + val fm = CheckpointFileManager.create(stateSchemaPath, hadoopConf) - testStream(result, OutputMode.Update())( - StartStream(), + val keySchema = new StructType().add("value", StringType) + val schema0 = StateStoreColFamilySchema( + "countState", 0, + keySchema, 0, + new StructType().add("value", LongType, nullable = shouldBeNullable), + 
Some(NoPrefixKeyStateEncoderSpec(keySchema)), + None + ) + val schema1 = StateStoreColFamilySchema( + "listState", 0, + keySchema, 0, + new StructType() + .add("id", LongType, nullable = shouldBeNullable) + .add("name", StringType), + Some(NoPrefixKeyStateEncoderSpec(keySchema)), + None + ) - AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), - // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5. - CheckNewAnswer(("a", 15)), // Output = max event time of a + val userKeySchema = new StructType() + .add("id", IntegerType, false) + .add("name", StringType) + val compositeKeySchema = new StructType() + .add("key", new StructType().add("value", StringType)) + .add("userKey", userKeySchema) + val schema2 = StateStoreColFamilySchema( + "mapState", 0, + compositeKeySchema, 0, + new StructType().add("value", StringType), + Some(PrefixKeyScanStateEncoderSpec(compositeKeySchema, 1)), + Option(userKeySchema) + ) - AddData(inputData, ("a", 4)), // Add data older than watermark for "a" - CheckNewAnswer(), // No output as data should get filtered by watermark + val schema3 = StateStoreColFamilySchema( + "$rowCounter_listState", 0, + keySchema, 0, + new StructType().add("count", LongType, nullable = shouldBeNullable), + Some(NoPrefixKeyStateEncoderSpec(keySchema)), + None + ) - AddData(inputData, ("a", 10)), // Add data newer than watermark for "a" - CheckNewAnswer(("a", 15)), // Max event time is still the same - // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15. - // Watermark is still 5 as max event time for all data is still 15. + val schema4 = StateStoreColFamilySchema( + "default", 0, + keySchema, 0, + new StructType().add("value", BinaryType), + Some(NoPrefixKeyStateEncoderSpec(keySchema)), + None + ) - AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a" - // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is 20. 
- CheckNewAnswer(("a", -1), ("b", 31)), // State for "a" should timeout and emit -1 - Execute { q => - // Filter for idle progress events and then verify the custom metrics for stateful operator - val progData = q.recentProgress.filter(prog => prog.stateOperators.size > 0) - assert(progData.filter(prog => - prog.stateOperators(0).customMetrics.get("numValueStateVars") > 0).size > 0) - assert(progData.filter(prog => - prog.stateOperators(0).customMetrics.get("numRegisteredTimers") > 0).size > 0) - assert(progData.filter(prog => - prog.stateOperators(0).customMetrics.get("numDeletedTimers") > 0).size > 0) - } - ) - } + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new StatefulProcessorWithCompositeTypes(useImplicits), + TimeMode.None(), + OutputMode.Update()) - test("transformWithState - timer duration should be reflected in metrics") { - val clock = new StreamManualClock - val inputData = MemoryStream[String] - val result = inputData.toDS() - .groupByKey(x => x) - .transformWithState( - new SleepingTimerProcessor, TimeMode.ProcessingTime(), OutputMode.Update()) + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "1"), ("b", "1")), + Execute { q => + q.lastProgress.runId + val schemaFilePath = fm.list(stateSchemaPath).toSeq.head.getPath + val providerId = StateStoreProviderId(StateStoreId( + checkpointDir.getCanonicalPath, 0, 0), q.lastProgress.runId) + val checker = new StateSchemaCompatibilityChecker(providerId, + hadoopConf, List(schemaFilePath)) + val colFamilySeq = checker.readSchemaFile() - testStream(result, OutputMode.Update())( - StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), - AddData(inputData, "a"), - AdvanceManualClock(1 * 1000), - // Side effect: timer scheduled for t = 1 + 10 = 11. - CheckNewAnswer(), - Execute { q => - val metrics = q.lastProgress.stateOperators(0).customMetrics - assert(metrics.get("numRegisteredTimers") === 1) - assert(metrics.get("timerProcessingTimeMs") < 2000) - }, + assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS == + q.lastProgress.stateOperators.head.customMetrics.get("numValueStateVars").toInt) + assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS == + q.lastProgress.stateOperators.head.customMetrics.get("numListStateVars").toInt) + assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS == + q.lastProgress.stateOperators.head.customMetrics.get("numMapStateVars").toInt) - AddData(inputData, "b"), - AdvanceManualClock(1 * 1000), - // Side effect: timer scheduled for t = 2 + 10 = 12. - CheckNewAnswer(), - Execute { q => - val metrics = q.lastProgress.stateOperators(0).customMetrics - assert(metrics.get("numRegisteredTimers") === 1) - assert(metrics.get("timerProcessingTimeMs") < 2000) - }, + assert(colFamilySeq.length == 5) + assert(colFamilySeq.map(_.toString).toSet == Set( + schema0, schema1, schema2, schema3, schema4 + ).map(_.toString)) + }, + StopStream + ) + } + } + } + } - AddData(inputData, "c"), - // Time is currently 2 and we need to advance past 12. So, advance by 11 seconds. 
- AdvanceManualClock(11 * 1000), - CheckNewAnswer("a", "b"), - Execute { q => - val metrics = q.lastProgress.stateOperators(0).customMetrics - assert(metrics.get("numRegisteredTimers") === 1) + test("transformWithState - verify that OperatorStateMetadataV2" + + " file is being written correctly") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { checkpointDir => + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) - // Both timers should have fired and taken 1 second each to process. - assert(metrics.get("timerProcessingTimeMs") >= 2000) - }, + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + StopStream, + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "2")), + StopStream + ) - StopStream - ) - } + val df = spark.read.format("state-metadata").load(checkpointDir.toString) - test("Use statefulProcessor without transformWithState -" + - " handle should be absent") { - val processor = new RunningCountStatefulProcessor() - val ex = intercept[Exception] { - processor.getHandle + // check first 6 columns of the row, and then read the last column of the row separately + checkAnswer( + df.select( + "operatorId", "operatorName", "stateStoreName", "numPartitions", "minBatchId", + "maxBatchId"), + Seq(Row(0, "transformWithStateExec", "default", 5, 0, 1)) + ) + val operatorPropsJson = df.select("operatorProperties").collect().head.getString(0) + val operatorProperties = TransformWithStateOperatorProperties.fromJson(operatorPropsJson) + assert(operatorProperties.timeMode == "NoTime") + assert(operatorProperties.outputMode == "Update") + assert(operatorProperties.stateVariables.length == 1) + assert(operatorProperties.stateVariables.head.stateName == "countState") + } } - checkError( - ex.asInstanceOf[SparkRuntimeException], - condition = "STATE_STORE_HANDLE_NOT_INITIALIZED", - parameters = Map.empty - ) } - test("transformWithState - batch should succeed") { - val inputData = Seq("a", "b") - val result = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessor(), - TimeMode.None(), - OutputMode.Append()) + test("test that different outputMode after query restart fails") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { checkpointDir => + val inputData = MemoryStream[String] + val result1 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) - val df = result.toDF() - checkAnswer(df, Seq(("a", "1"), ("b", "1")).toDF()) + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + StopStream + ) + val result2 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Append()) + + testStream(result2, OutputMode.Update())( + 
StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + ExpectFailure[StateStoreInvalidConfigAfterRestart] { e => + checkError( + e.asInstanceOf[SparkUnsupportedOperationException], + condition = "STATE_STORE_INVALID_CONFIG_AFTER_RESTART", + parameters = Map( + "configName" -> "outputMode", + "oldConfig" -> "Update", + "newConfig" -> "Append") + ) + } + ) + } + } } - test("transformWithState - test deleteIfExists operator") { + test("test that changing between different state variable types fails") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - withTempDir { chkptDir => - val dirPath = chkptDir.getCanonicalPath - val inputData = MemoryStream[(String, String)] - val stream1 = inputData.toDS() - .groupByKey(x => x._1) - .transformWithState(new RunningCountMostRecentStatefulProcessor(), + withTempDir { checkpointDir => + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), TimeMode.None(), OutputMode.Update()) - val stream2 = inputData.toDS() - .groupByKey(x => x._1) - .transformWithState(new MostRecentStatefulProcessorWithDeletion(), + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + StopStream + ) + val result2 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountListStatefulProcessor(), TimeMode.None(), OutputMode.Update()) + testStream(result2, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + ExpectFailure[StateStoreInvalidVariableTypeChange] { t => + checkError( + t.asInstanceOf[SparkUnsupportedOperationException], + condition = "STATE_STORE_INVALID_VARIABLE_TYPE_CHANGE", + parameters = Map( + "stateVarName" -> "countState", + "newType" -> "ListState", + "oldType" -> "ValueState") + ) + } + ) + } + } + } - testStream(stream1, OutputMode.Update())( - StartStream(checkpointLocation = dirPath), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "1", "")), + test("test that different timeMode after query restart fails") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { checkpointDir => + val clock = new StreamManualClock + val inputData = MemoryStream[String] + val result1 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), StopStream ) - testStream(stream2, OutputMode.Update())( - StartStream(checkpointLocation = dirPath), - AddData(inputData, ("a", "str2"), ("b", "str3")), - CheckNewAnswer(("a", "str1"), - ("b", "")), // should not factor in previous count state - Execute { q => - assert(q.lastProgress.stateOperators(0).customMetrics.get("numValueStateVars") > 0) - assert(q.lastProgress.stateOperators(0).customMetrics.get("numDeletedStateVars") > 0) - }, - StopStream + val result2 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new 
RunningCountStatefulProcessor(), + TimeMode.ProcessingTime(), + OutputMode.Update()) + testStream(result2, OutputMode.Update())( + StartStream( + checkpointLocation = checkpointDir.getCanonicalPath, + trigger = Trigger.ProcessingTime("1 second"), + triggerClock = clock), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + ExpectFailure[StateStoreInvalidConfigAfterRestart] { e => + checkError( + e.asInstanceOf[SparkUnsupportedOperationException], + condition = "STATE_STORE_INVALID_CONFIG_AFTER_RESTART", + parameters = Map( + "configName" -> "timeMode", + "oldConfig" -> "NoTime", + "newConfig" -> "ProcessingTime") + ) + } ) } } } - test("transformWithState - two input streams") { + test("test query restart with new state variable succeeds") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - val inputData1 = MemoryStream[String] - val inputData2 = MemoryStream[String] + withTempDir { checkpointDir => + val clock = new StreamManualClock - val result = inputData1.toDS() - .union(inputData2.toDS()) - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessor(), - TimeMode.None(), - OutputMode.Update()) + val inputData1 = MemoryStream[String] + val result1 = inputData1.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.ProcessingTime(), + OutputMode.Update()) - testStream(result, OutputMode.Update())( - AddData(inputData1, "a"), - CheckNewAnswer(("a", "1")), - AddData(inputData2, "a", "b"), - CheckNewAnswer(("a", "2"), ("b", "1")), - AddData(inputData1, "a", "b"), // should remove state for "a" and not return anything for a - CheckNewAnswer(("b", "2")), - AddData(inputData1, "d", "e"), - AddData(inputData2, "a", "c"), // should recreate state for "a" and return count as 1 - CheckNewAnswer(("a", "1"), ("c", "1"), ("d", "1"), ("e", "1")), - StopStream - ) + testStream(result1, OutputMode.Update())( + StartStream( + checkpointLocation = checkpointDir.getCanonicalPath, + trigger = Trigger.ProcessingTime("1 second"), + triggerClock = clock), + AddData(inputData1, "a"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "1")), + StopStream + ) + + val result2 = inputData1.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorWithProcTimeTimer(), + TimeMode.ProcessingTime(), + OutputMode.Update()) + + testStream(result2, OutputMode.Update())( + StartStream( + checkpointLocation = checkpointDir.getCanonicalPath, + trigger = Trigger.ProcessingTime("1 second"), + triggerClock = clock), + AddData(inputData1, "a"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "2")), + StopStream + ) + } } } - test("transformWithState - three input streams") { - withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName, + test("test exceeding schema file threshold throws error") { + withSQLConf( + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - val inputData1 = MemoryStream[String] - val inputData2 = MemoryStream[String] - val inputData3 = MemoryStream[String] + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString, + SQLConf.STREAMING_MAX_NUM_STATE_SCHEMA_FILES.key -> 1.toString) { + withTempDir { dirPath => + val inputData = MemoryStream[(String, String)] + // First run with 
both count and mostRecent states + val stream1 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(stream1, OutputMode.Update())( + StartStream(checkpointLocation = dirPath.getCanonicalPath), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "1", "")), + StopStream + ) - // union 3 input streams - val result = inputData1.toDS() - .union(inputData2.toDS()) - .union(inputData3.toDS()) - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessor(), - TimeMode.None(), - OutputMode.Update()) + // Second run deletes count state but keeps mostRecent + val stream2 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new MostRecentStatefulProcessorWithDeletion(), + TimeMode.None(), + OutputMode.Update()) - testStream(result, OutputMode.Update())( - AddData(inputData1, "a"), - CheckNewAnswer(("a", "1")), - AddData(inputData2, "a", "b"), - CheckNewAnswer(("a", "2"), ("b", "1")), - AddData(inputData3, "a", "b"), // should remove state for "a" and not return anything for a - CheckNewAnswer(("b", "2")), - AddData(inputData1, "d", "e"), - AddData(inputData2, "a", "c"), // should recreate state for "a" and return count as 1 - CheckNewAnswer(("a", "1"), ("c", "1"), ("d", "1"), ("e", "1")), - AddData(inputData3, "a", "c", "d", "e"), - CheckNewAnswer(("a", "2"), ("c", "2"), ("d", "2"), ("e", "2")), - StopStream - ) + testStream(stream2, OutputMode.Update())( + StartStream(checkpointLocation = dirPath.getCanonicalPath), + AddData(inputData, ("a", "str2"), ("b", "str3")), + ExpectFailure[StateStoreStateSchemaFilesThresholdExceeded] { t => + checkError( + t.asInstanceOf[StateStoreStateSchemaFilesThresholdExceeded], + condition = "STATE_STORE_STATE_SCHEMA_FILES_THRESHOLD_EXCEEDED", + parameters = Map( + "numStateSchemaFiles" -> "2", + "maxStateSchemaFiles" -> "1", + "removedColumnFamilies" -> "(countState)", + "addedColumnFamilies" -> "()" + ) + ) + } + ) + } } } - test("transformWithState - two input streams, different key type") { + test("test query restart succeeds") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - val inputData1 = MemoryStream[String] - val inputData2 = MemoryStream[Long] + withTempDir { checkpointDir => + val inputData = MemoryStream[String] + val result1 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) - val result = inputData1.toDS() - // union inputData2 by casting it to a String - .union(inputData2.toDS().map(_.toString)) - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessor(), - TimeMode.None(), - OutputMode.Update()) + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + StopStream + ) + val result2 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorWithProcTimeTimer(), + TimeMode.None(), + OutputMode.Update()) - testStream(result, OutputMode.Update())( - AddData(inputData1, "1"), - CheckNewAnswer(("1", "1")), - AddData(inputData2, 1L, 2L), - CheckNewAnswer(("1", "2"), ("2", "1")), - AddData(inputData1, "1", "2"), // should remove state for "1" and not return anything. 
- CheckNewAnswer(("2", "2")), - AddData(inputData1, "4", "5"), - AddData(inputData2, 1L, 3L), // should recreate state for "1" and return count as 1 - CheckNewAnswer(("1", "1"), ("3", "1"), ("4", "1"), ("5", "1")), - StopStream - ) + testStream(result2, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "2")), + StopStream + ) + } } } - /** Create a text file with a single data item */ - private def createFile(data: String, srcDir: File): File = - stringToFile(new File(srcDir, s"${UUID.randomUUID()}.txt"), data) - - private def createFileStream(srcDir: File): Dataset[(String, String)] = { - spark - .readStream - .option("maxFilesPerTrigger", "1") - .text(srcDir.getCanonicalPath) - .select("value").as[String] - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessor(), - TimeMode.None(), - OutputMode.Update()) - } - - test("transformWithState - availableNow trigger mode, rate limit is respected") { + test("SPARK-49070: transformWithState - valid initial state plan") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) { withTempDir { srcDir => - Seq("a", "b", "c").foreach(createFile(_, srcDir)) - - // Set up a query to read text files one at a time val df = createFileStream(srcDir) - testStream(df)( - StartStream(trigger = Trigger.AvailableNow()), - ProcessAllAvailable(), - CheckNewAnswer(("a", "1"), ("b", "1"), ("c", "1")), - StopStream, - Execute { _ => - createFile("a", srcDir) - }, - StartStream(trigger = Trigger.AvailableNow()), - ProcessAllAvailable(), - CheckNewAnswer(("a", "2")) - ) - var index = 0 - val foreachBatchDf = df.writeStream + + val q = df.writeStream .foreachBatch((_: Dataset[(String, String)], _: Long) => { index += 1 }) - .trigger(Trigger.AvailableNow()) + .trigger(Trigger.AvailableNow) .start() try { - foreachBatchDf.awaitTermination() - assert(index == 4) - } finally { - foreachBatchDf.stop() - } - } - } - } - - test("transformWithState - availableNow trigger mode, multiple restarts") { - withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName) { - withTempDir { srcDir => - Seq("a", "b", "c").foreach(createFile(_, srcDir)) - val df = createFileStream(srcDir) - - var index = 0 - - def startTriggerAvailableNowQueryAndCheck(expectedIdx: Int): Unit = { - val q = df.writeStream - .foreachBatch((_: Dataset[(String, String)], _: Long) => { - index += 1 - }) - .trigger(Trigger.AvailableNow) - .start() - try { - assert(q.awaitTermination(streamingTimeout.toMillis)) - assert(index == expectedIdx) - } finally { - q.stop() - } - } - // start query for the first time - startTriggerAvailableNowQueryAndCheck(3) - - // add two files and restart - createFile("a", srcDir) - createFile("b", srcDir) - startTriggerAvailableNowQueryAndCheck(8) - - // try restart again - createFile("d", srcDir) - startTriggerAvailableNowQueryAndCheck(14) - } - } - } - - Seq(false, true).foreach { useImplicits => - test("transformWithState - verify StateSchemaV3 writes " + - s"correct SQL schema of key/value with useImplicits=$useImplicits") { - withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - withTempDir { checkpointDir => - // When Avro is used, we want to set the StructFields to nullable - val shouldBeNullable = usingAvroEncoding() - val metadataPathPostfix 
= "state/0/_stateSchema/default" - val stateSchemaPath = new Path(checkpointDir.toString, - s"$metadataPathPostfix") - val hadoopConf = spark.sessionState.newHadoopConf() - val fm = CheckpointFileManager.create(stateSchemaPath, hadoopConf) - - val keySchema = new StructType().add("value", StringType) - val schema0 = StateStoreColFamilySchema( - "countState", 0, - keySchema, 0, - new StructType().add("value", LongType, nullable = shouldBeNullable), - Some(NoPrefixKeyStateEncoderSpec(keySchema)), - None - ) - val schema1 = StateStoreColFamilySchema( - "listState", 0, - keySchema, 0, - new StructType() - .add("id", LongType, nullable = shouldBeNullable) - .add("name", StringType), - Some(NoPrefixKeyStateEncoderSpec(keySchema)), - None - ) - - val userKeySchema = new StructType() - .add("id", IntegerType, false) - .add("name", StringType) - val compositeKeySchema = new StructType() - .add("key", new StructType().add("value", StringType)) - .add("userKey", userKeySchema) - val schema2 = StateStoreColFamilySchema( - "mapState", 0, - compositeKeySchema, 0, - new StructType().add("value", StringType), - Some(PrefixKeyScanStateEncoderSpec(compositeKeySchema, 1)), - Option(userKeySchema) - ) - - val schema3 = StateStoreColFamilySchema( - "$rowCounter_listState", 0, - keySchema, 0, - new StructType().add("count", LongType, nullable = shouldBeNullable), - Some(NoPrefixKeyStateEncoderSpec(keySchema)), - None - ) - - val schema4 = StateStoreColFamilySchema( - "default", 0, - keySchema, 0, - new StructType().add("value", BinaryType), - Some(NoPrefixKeyStateEncoderSpec(keySchema)), - None - ) + assert(q.awaitTermination(streamingTimeout.toMillis)) - val inputData = MemoryStream[String] - val result = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new StatefulProcessorWithCompositeTypes(useImplicits), - TimeMode.None(), - OutputMode.Update()) + val sparkPlan = + q.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.executedPlan + val transformWithStateExec = sparkPlan.collect { + case p: TransformWithStateExec => p + }.head - testStream(result, OutputMode.Update())( - StartStream(checkpointLocation = checkpointDir.getCanonicalPath), - AddData(inputData, "a", "b"), - CheckNewAnswer(("a", "1"), ("b", "1")), - Execute { q => - q.lastProgress.runId - val schemaFilePath = fm.list(stateSchemaPath).toSeq.head.getPath - val providerId = StateStoreProviderId(StateStoreId( - checkpointDir.getCanonicalPath, 0, 0), q.lastProgress.runId) - val checker = new StateSchemaCompatibilityChecker(providerId, - hadoopConf, List(schemaFilePath)) - val colFamilySeq = checker.readSchemaFile() + assert(!transformWithStateExec.hasInitialState) - assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS == - q.lastProgress.stateOperators.head.customMetrics.get("numValueStateVars").toInt) - assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS == - q.lastProgress.stateOperators.head.customMetrics.get("numListStateVars").toInt) - assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS == - q.lastProgress.stateOperators.head.customMetrics.get("numMapStateVars").toInt) + // EnsureRequirements should not apply on the initial state plan + val exchange = transformWithStateExec.initialState.collect { + case s: ShuffleExchangeExec => s + } - assert(colFamilySeq.length == 5) - assert(colFamilySeq.map(_.toString).toSet == Set( - schema0, schema1, schema2, schema3, schema4 - ).map(_.toString)) - }, - StopStream - ) + assert(exchange.isEmpty) + } finally { + q.stop() } } } } - test("transformWithState - verify 
that OperatorStateMetadataV2" + - " file is being written correctly") { + private def getFiles(path: Path): Array[FileStatus] = { + val hadoopConf = spark.sessionState.newHadoopConf() + val fileManager = CheckpointFileManager.create(path, hadoopConf) + fileManager.list(path) + } + + private def getStateSchemaPath(stateCheckpointPath: Path): Path = { + new Path(stateCheckpointPath, "_stateSchema/default/") + } + + // TODO: [SPARK-50845] Re-enable tests after StateSchemaV3 threshold change + ignore("transformWithState - verify that metadata and schema logs are purged") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - withTempDir { checkpointDir => - val inputData = MemoryStream[String] - val result = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessor(), + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString, + SQLConf.MIN_BATCHES_TO_RETAIN.key -> "1") { + withTempDir { chkptDir => + val stateOpIdPath = new Path(new Path(chkptDir.getCanonicalPath, "state"), "0") + val stateSchemaPath = getStateSchemaPath(stateOpIdPath) + + val metadataPath = OperatorStateMetadataV2.metadataDirPath(stateOpIdPath) + // in this test case, we are changing the state spec back and forth + // to trigger the writing of the schema and metadata files + val inputData = MemoryStream[(String, String)] + val result1 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessor(), TimeMode.None(), OutputMode.Update()) - - testStream(result, OutputMode.Update())( - StartStream(checkpointLocation = checkpointDir.getCanonicalPath), - AddData(inputData, "a"), - CheckNewAnswer(("a", "1")), - StopStream, - StartStream(checkpointLocation = checkpointDir.getCanonicalPath), - AddData(inputData, "a"), - CheckNewAnswer(("a", "2")), + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "1", "")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, StopStream ) + val result2 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new MostRecentStatefulProcessorWithDeletion(), + TimeMode.None(), + OutputMode.Update()) + testStream(result2, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str2")), + CheckNewAnswer(("a", "str1")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + // assert that a metadata and schema file has been written for each run + // as state variables have been deleted + assert(getFiles(metadataPath).length == 2) + assert(getFiles(stateSchemaPath).length == 2) - val df = spark.read.format("state-metadata").load(checkpointDir.toString) - - // check first 6 columns of the row, and then read the last column of the row separately - checkAnswer( - df.select( - "operatorId", "operatorName", "stateStoreName", "numPartitions", "minBatchId", - "maxBatchId"), - Seq(Row(0, "transformWithStateExec", "default", 5, 0, 1)) + val result3 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + 
testStream(result3, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str3")), + CheckNewAnswer(("a", "1", "str2")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream ) - val operatorPropsJson = df.select("operatorProperties").collect().head.getString(0) - val operatorProperties = TransformWithStateOperatorProperties.fromJson(operatorPropsJson) - assert(operatorProperties.timeMode == "NoTime") - assert(operatorProperties.outputMode == "Update") - assert(operatorProperties.stateVariables.length == 1) - assert(operatorProperties.stateVariables.head.stateName == "countState") + // because we don't change the schema for this run, there won't + // be a new schema file written. + testStream(result3, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str4")), + CheckNewAnswer(("a", "2", "str3")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + // by the end of the test, there have been 4 batches, + // so the metadata and schema logs, and commitLog has been purged + // for batches 0 and 1 so metadata and schema files exist for batches 0, 1, 2, 3 + // and we only need to keep metadata files for batches 2, 3, and the since schema + // hasn't changed between batches 2, 3, we only keep the schema file for batch 2 + assert(getFiles(metadataPath).length == 2) + assert(getFiles(stateSchemaPath).length == 1) } } } - testWithEncoding("unsaferow")("test that invalid schema evolution " + - "fails query for column family") { + // TODO: [SPARK-50845] Re-enable tests after StateSchemaV3 threshold change + ignore("transformWithState - verify that schema file " + + "is kept after metadata is purged") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - withTempDir { checkpointDir => - val inputData = MemoryStream[String] + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString, + SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2") { + withTempDir { chkptDir => + val stateOpIdPath = new Path(new Path(chkptDir.getCanonicalPath, "state"), "0") + val stateSchemaPath = getStateSchemaPath(stateOpIdPath) + + val metadataPath = OperatorStateMetadataV2.metadataDirPath(stateOpIdPath) + // in this test case, we are changing the state spec back and forth + // to trigger the writing of the schema and metadata files + val inputData = MemoryStream[(String, String)] val result1 = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessor(), + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessor(), TimeMode.None(), OutputMode.Update()) - testStream(result1, OutputMode.Update())( - StartStream(checkpointLocation = checkpointDir.getCanonicalPath), - AddData(inputData, "a"), - CheckNewAnswer(("a", "1")), + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "1", "")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = 
chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "2", "str1")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, StopStream ) val result2 = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessorInt(), + .groupByKey(x => x._1) + .transformWithState(new MostRecentStatefulProcessorWithDeletion(), TimeMode.None(), OutputMode.Update()) - testStream(result2, OutputMode.Update())( - StartStream(checkpointLocation = checkpointDir.getCanonicalPath), - AddData(inputData, "a"), - ExpectFailure[StateStoreValueSchemaNotCompatible] { - (t: Throwable) => { - assert(t.getMessage.contains("Please check number and type of fields.")) + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str2")), + CheckNewAnswer(("a", "str1")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) } - } + }, + StopStream + ) + assert(getFiles(metadataPath).length == 3) + assert(getFiles(stateSchemaPath).length == 2) + + val result3 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + testStream(result3, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str3")), + CheckNewAnswer(("a", "1", "str2")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream ) + // metadata files should be kept for batches 1, 2, 3 + // schema files should be kept for batches 0, 2, 3 + assert(getFiles(metadataPath).length == 3) + assert(getFiles(stateSchemaPath).length == 3) + // we want to ensure that we can read batch 1 even though the + // metadata file for batch 0 was removed + val batch1Df = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, chkptDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "countState") + .option(StateSourceOptions.BATCH_ID, 1) + .load() + + val batch1AnsDf = batch1Df.selectExpr( + "key.value AS groupingKey", + "value.value AS valueId") + + checkAnswer(batch1AnsDf, Seq(Row("a", 2L))) + + val batch3Df = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, chkptDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "countState") + .option(StateSourceOptions.BATCH_ID, 3) + .load() + + val batch3AnsDf = batch3Df.selectExpr( + "key.value AS groupingKey", + "value.value AS valueId") + checkAnswer(batch3AnsDf, Seq(Row("a", 1L))) } } } - testWithEncoding("avro")("test that invalid schema evolution " + - "fails query for column family") { + test("state data source integration - value state supports time travel") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - withTempDir { checkpointDir => - val inputData = MemoryStream[String] + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString, + SQLConf.MIN_BATCHES_TO_RETAIN.key -> "5") { + withTempDir { chkptDir => + // in this test case, we are changing the state spec back and forth + // to trigger the writing of the schema and metadata files + val inputData = MemoryStream[(String, String)] val result1 = 
inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessor(), + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessor(), TimeMode.None(), OutputMode.Update()) - testStream(result1, OutputMode.Update())( - StartStream(checkpointLocation = checkpointDir.getCanonicalPath), - AddData(inputData, "a"), - CheckNewAnswer(("a", "1")), + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "1", "")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "2", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "3", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "4", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "5", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "6", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "7", "str1")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, StopStream ) val result2 = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessorInt(), + .groupByKey(x => x._1) + .transformWithState(new MostRecentStatefulProcessorWithDeletion(), TimeMode.None(), OutputMode.Update()) - testStream(result2, OutputMode.Update())( - StartStream(checkpointLocation = checkpointDir.getCanonicalPath), - AddData(inputData, "a"), - ExpectFailure[StateStoreInvalidValueSchemaEvolution] { e => - checkError( - e.asInstanceOf[SparkUnsupportedOperationException], - condition = "STATE_STORE_INVALID_VALUE_SCHEMA_EVOLUTION", - parameters = Map( - "oldValueSchema" -> "StructType(StructField(value,LongType,true))", - "newValueSchema" -> "StructType(StructField(value,IntegerType,true))") - ) - } + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str2")), + CheckNewAnswer(("a", "str1")), + AddData(inputData, ("a", "str3")), + CheckNewAnswer(("a", "str2")), + AddData(inputData, ("a", "str4")), + CheckNewAnswer(("a", "str3")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream ) - } - } - } - test("test that different outputMode after query restart fails") { - withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - withTempDir { checkpointDir => - val inputData = MemoryStream[String] - val result1 = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessor(), - TimeMode.None(), - OutputMode.Update()) + // Batches 0-7: countState, mostRecent + // Batches 8-9: countState - testStream(result1, OutputMode.Update())( - StartStream(checkpointLocation = checkpointDir.getCanonicalPath), - AddData(inputData, "a"), - CheckNewAnswer(("a", "1")), - StopStream - ) - val result2 = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessor(), - TimeMode.None(), - OutputMode.Append()) + // By this time, offset and commit logs for batches 0-3 have been purged. 
+ // However, if we want to read the data for batch 4, we need to read the corresponding + // metadata and schema file for batch 4, or the latest files that correspond to + // batch 4 (in this case, the files were written for batch 0). + // We want to test the behavior where the metadata files are preserved so that we can + // read from the state data source, even if the commit and offset logs are purged for + // a particular batch - testStream(result2, OutputMode.Update())( - StartStream(checkpointLocation = checkpointDir.getCanonicalPath), - AddData(inputData, "a"), - ExpectFailure[StateStoreInvalidConfigAfterRestart] { e => - checkError( - e.asInstanceOf[SparkUnsupportedOperationException], - condition = "STATE_STORE_INVALID_CONFIG_AFTER_RESTART", - parameters = Map( - "configName" -> "outputMode", - "oldConfig" -> "Update", - "newConfig" -> "Append") - ) - } + val df = spark.read.format("state-metadata").load(chkptDir.toString) + + // check the min and max batch ids that we have data for + checkAnswer( + df.select( + "operatorId", "operatorName", "stateStoreName", "numPartitions", "minBatchId", + "maxBatchId"), + Seq(Row(0, "transformWithStateExec", "default", 5, 4, 9)) ) + + val countStateDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, chkptDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "countState") + .option(StateSourceOptions.BATCH_ID, 4) + .load() + + val countStateAnsDf = countStateDf.selectExpr( + "key.value AS groupingKey", + "value.value AS valueId") + checkAnswer(countStateAnsDf, Seq(Row("a", 5L))) + + val mostRecentDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, chkptDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "mostRecent") + .option(StateSourceOptions.BATCH_ID, 4) + .load() + + val mostRecentAnsDf = mostRecentDf.selectExpr( + "key.value AS groupingKey", + "value.value") + checkAnswer(mostRecentAnsDf, Seq(Row("a", "str1"))) } } } - test("test that changing between different state variable types fails") { + // TODO: [SPARK-50845] Re-enable tests after StateSchemaV3 threshold change + ignore("transformWithState - verify that all metadata and schema logs are not purged") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - withTempDir { checkpointDir => - val inputData = MemoryStream[String] - val result = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessor(), + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString, + SQLConf.MIN_BATCHES_TO_RETAIN.key -> "3") { + withTempDir { chkptDir => + val inputData = MemoryStream[(String, String)] + val result1 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessor(), TimeMode.None(), OutputMode.Update()) - - testStream(result, OutputMode.Update())( - StartStream(checkpointLocation = checkpointDir.getCanonicalPath), - AddData(inputData, "a"), - CheckNewAnswer(("a", "1")), + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "1", "")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "2", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "3", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "4", "str1")), + AddData(inputData, ("a", 
"str1")), + CheckNewAnswer(("a", "5", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "6", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "7", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "8", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "9", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "10", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "11", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "12", "str1")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, StopStream ) - val result2 = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountListStatefulProcessor(), - TimeMode.None(), - OutputMode.Update()) - testStream(result2, OutputMode.Update())( - StartStream(checkpointLocation = checkpointDir.getCanonicalPath), - AddData(inputData, "a"), - ExpectFailure[StateStoreInvalidVariableTypeChange] { t => - checkError( - t.asInstanceOf[SparkUnsupportedOperationException], - condition = "STATE_STORE_INVALID_VARIABLE_TYPE_CHANGE", - parameters = Map( - "stateVarName" -> "countState", - "newType" -> "ListState", - "oldType" -> "ValueState") - ) - } + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "13", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "14", "str1")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream ) + + val stateOpIdPath = new Path(new Path(chkptDir.getCanonicalPath, "state"), "0") + val stateSchemaPath = getStateSchemaPath(stateOpIdPath) + + val metadataPath = OperatorStateMetadataV2.metadataDirPath(stateOpIdPath) + + // Metadata files exist for batches 0, 12, and the thresholdBatchId is 8 + // as this is the earliest batchId for which the commit log is not present, + // so we need to keep metadata files for batch 0 so we can read the commit + // log correspondingly + assert(getFiles(metadataPath).length == 2) + assert(getFiles(stateSchemaPath).length == 1) } } } - test("test that different timeMode after query restart fails") { + // TODO: [SPARK-50845] Re-enable tests after StateSchemaV3 threshold change + ignore("transformWithState - verify that no metadata and schema logs are purged after" + + " removing column family") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - withTempDir { checkpointDir => - val clock = new StreamManualClock - val inputData = MemoryStream[String] + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString, + SQLConf.MIN_BATCHES_TO_RETAIN.key -> "3") { + withTempDir { chkptDir => + val inputData = MemoryStream[(String, String)] val result1 = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessor(), + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessor(), TimeMode.None(), OutputMode.Update()) - testStream(result1, OutputMode.Update())( - StartStream(checkpointLocation = checkpointDir.getCanonicalPath), - AddData(inputData, "a"), - CheckNewAnswer(("a", "1")), + StartStream(checkpointLocation = 
chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "1", "")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "2", "str1")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "1", "")), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "2", "str1")), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "3", "str1")), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "4", "str1")), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "5", "str1")), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "6", "str1")), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "7", "str1")), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "8", "str1")), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "9", "str1")), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "10", "str1")), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "11", "str1")), + AddData(inputData, ("b", "str1")), + CheckNewAnswer(("b", "12", "str1")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, StopStream ) val result2 = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessor(), - TimeMode.ProcessingTime(), + .groupByKey(x => x._1) + .transformWithState(new MostRecentStatefulProcessorWithDeletion(), + TimeMode.None(), OutputMode.Update()) + testStream(result2, OutputMode.Update())( - StartStream( - checkpointLocation = checkpointDir.getCanonicalPath, - trigger = Trigger.ProcessingTime("1 second"), - triggerClock = clock), - AddData(inputData, "a"), - AdvanceManualClock(1 * 1000), - ExpectFailure[StateStoreInvalidConfigAfterRestart] { e => - checkError( - e.asInstanceOf[SparkUnsupportedOperationException], - condition = "STATE_STORE_INVALID_CONFIG_AFTER_RESTART", - parameters = Map( - "configName" -> "timeMode", - "oldConfig" -> "NoTime", - "newConfig" -> "ProcessingTime") - ) - } + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("b", "str2")), + CheckNewAnswer(("b", "str1")), + AddData(inputData, ("b", "str3")), + CheckNewAnswer(("b", "str2")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream ) + + val stateOpIdPath = new Path(new Path(chkptDir.getCanonicalPath, "state"), "0") + val stateSchemaPath = getStateSchemaPath(stateOpIdPath) + + val metadataPath = OperatorStateMetadataV2.metadataDirPath(stateOpIdPath) + + // Metadata files are written for batches 0, 2, and 14. + // Schema files are written for 0, 14 + // At the beginning of the last query run, the thresholdBatchId is 11. + // However, we would need both schema files to be preserved, if we want to + // be able to read from batch 11 onwards. 
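+          // Recap of the runs above (derived from the processors used in this test):
+          // batches 0-1 and 2-13 use RunningCountMostRecentStatefulProcessor
+          // (countState and mostRecent), while batches 14-15 use
+          // MostRecentStatefulProcessorWithDeletion (mostRecent only, with
+          // countState removed).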
+ assert(getFiles(metadataPath).length == 2) + assert(getFiles(stateSchemaPath).length == 2) } } } +} - testWithEncoding("unsaferow")("test that introducing TTL after restart fails query") { - withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - withTempDir { checkpointDir => - val inputData = MemoryStream[String] - val clock = new StreamManualClock - val result = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessor(), - TimeMode.ProcessingTime(), - OutputMode.Update()) +class TransformWithStateValidationSuite extends StateStoreMetricsTest { + import testImplicits._ - testStream(result, OutputMode.Update())( - StartStream( - trigger = Trigger.ProcessingTime("1 second"), - checkpointLocation = checkpointDir.getCanonicalPath, - triggerClock = clock), - AddData(inputData, "a"), - AdvanceManualClock(1 * 1000), - CheckNewAnswer(("a", "1")), - AdvanceManualClock(1 * 1000), - StopStream - ) - val result2 = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessorWithTTL(), - TimeMode.ProcessingTime(), - OutputMode.Update()) - testStream(result2, OutputMode.Update())( - StartStream( - trigger = Trigger.ProcessingTime("1 second"), - checkpointLocation = checkpointDir.getCanonicalPath, - triggerClock = clock), - AddData(inputData, "a"), - AdvanceManualClock(1 * 1000), - ExpectFailure[StateStoreValueSchemaNotCompatible] { t => - checkError( - t.asInstanceOf[SparkUnsupportedOperationException], - condition = "STATE_STORE_VALUE_SCHEMA_NOT_COMPATIBLE", - parameters = Map( - "storedValueSchema" -> "StructType(StructField(value,LongType,false))", - "newValueSchema" -> - ("StructType(StructField(value,StructType(StructField(value,LongType,false))," + - "true),StructField(ttlExpirationMs,LongType,true))") - ) - ) - } - ) + test("transformWithState - streaming with hdfsStateStoreProvider should fail") { + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + AddData(inputData, "a"), + ExpectFailure[StateStoreMultipleColumnFamiliesNotSupportedException] { t => + assert(t.getMessage.contains("not supported")) } - } + ) } - testWithEncoding("avro")("test that introducing TTL after restart fails query") { - withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - withTempDir { checkpointDir => - val inputData = MemoryStream[String] - val clock = new StreamManualClock - val result = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessor(), - TimeMode.ProcessingTime(), - OutputMode.Update()) - - testStream(result, OutputMode.Update())( - StartStream( - trigger = Trigger.ProcessingTime("1 second"), - checkpointLocation = checkpointDir.getCanonicalPath, - triggerClock = clock), - AddData(inputData, "a"), - AdvanceManualClock(1 * 1000), - CheckNewAnswer(("a", "1")), - AdvanceManualClock(1 * 1000), - StopStream - ) - val result2 = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessorWithTTL(), - TimeMode.ProcessingTime(), - OutputMode.Update()) - testStream(result2, OutputMode.Update())( 
- StartStream( - trigger = Trigger.ProcessingTime("1 second"), - checkpointLocation = checkpointDir.getCanonicalPath, - triggerClock = clock), - AddData(inputData, "a"), - AdvanceManualClock(1 * 1000), - ExpectFailure[StateStoreInvalidValueSchemaEvolution] { e => - checkError( - e.asInstanceOf[SparkUnsupportedOperationException], - condition = "STATE_STORE_INVALID_VALUE_SCHEMA_EVOLUTION", - parameters = Map( - "newValueSchema" -> ("StructType(StructField(value,StructType(StructField(" + - "value,LongType,true)),true),StructField(ttlExpirationMs,LongType,true))"), - "oldValueSchema" -> "StructType(StructField(value,LongType,true))") - ) - } - ) + test("transformWithStateWithInitialState - streaming with hdfsStateStoreProvider should fail") { + val inputData = MemoryStream[InitInputRow] + val initDf = Seq(("init_1", 40.0), ("init_2", 100.0)).toDS() + .groupByKey(x => x._1) + .mapValues(x => x) + val result = inputData.toDS() + .groupByKey(x => x.key) + .transformWithState(new AccumulateStatefulProcessorWithInitState(), + TimeMode.None(), OutputMode.Append(), initDf + ) + testStream(result, OutputMode.Update())( + AddData(inputData, InitInputRow("a", "add", -1.0)), + ExpectFailure[StateStoreMultipleColumnFamiliesNotSupportedException] { + (t: Throwable) => { + assert(t.getMessage.contains("not supported")) + } } - } + ) } - test("test query restart with new state variable succeeds") { - withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - withTempDir { checkpointDir => - val clock = new StreamManualClock + test("transformWithState - validate timeModes") { + // validation tests should pass for TimeMode.None + TransformWithStateVariableUtils.validateTimeMode(TimeMode.None(), None) + TransformWithStateVariableUtils.validateTimeMode(TimeMode.None(), Some(10L)) + + // validation tests should fail for TimeMode.ProcessingTime and TimeMode.EventTime + // when time values are not provided + val ex = intercept[SparkException] { + TransformWithStateVariableUtils.validateTimeMode(TimeMode.ProcessingTime(), None) + } + assert(ex.getMessage.contains("Failed to find time values")) + TransformWithStateVariableUtils.validateTimeMode(TimeMode.ProcessingTime(), Some(10L)) - val inputData1 = MemoryStream[String] - val result1 = inputData1.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessor(), - TimeMode.ProcessingTime(), - OutputMode.Update()) + val ex1 = intercept[SparkException] { + TransformWithStateVariableUtils.validateTimeMode(TimeMode.EventTime(), None) + } + assert(ex1.getMessage.contains("Failed to find time values")) + TransformWithStateVariableUtils.validateTimeMode(TimeMode.EventTime(), Some(10L)) + } +} - testStream(result1, OutputMode.Update())( - StartStream( - checkpointLocation = checkpointDir.getCanonicalPath, - trigger = Trigger.ProcessingTime("1 second"), - triggerClock = clock), - AddData(inputData1, "a"), - AdvanceManualClock(1 * 1000), - CheckNewAnswer(("a", "1")), - StopStream - ) +class TransformWithStateAvroEncodingSuite extends TransformWithStateTest { - val result2 = inputData1.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessorWithProcTimeTimer(), - TimeMode.ProcessingTime(), - OutputMode.Update()) + import testImplicits._ - testStream(result2, OutputMode.Update())( - StartStream( - checkpointLocation = checkpointDir.getCanonicalPath, - trigger = 
Trigger.ProcessingTime("1 second"), - triggerClock = clock), - AddData(inputData1, "a"), - AdvanceManualClock(1 * 1000), - CheckNewAnswer(("a", "2")), - StopStream - ) + override protected def test(testName: String, testTags: Tag*)(testBody: => Any) + (implicit pos: Position): Unit = { + super.test(s"$testName (encoding = Avro)", testTags: _*) { + withSQLConf(SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> "avro") { + testBody } } } - test("test exceeding schema file threshold throws error") { - withSQLConf( - SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, + test("transformWithState - value schema threshold exceeded") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString, - SQLConf.STREAMING_MAX_NUM_STATE_SCHEMA_FILES.key -> 1.toString) { - withTempDir { dirPath => - val inputData = MemoryStream[(String, String)] - // First run with both count and mostRecent states - val stream1 = inputData.toDS() - .groupByKey(x => x._1) - .transformWithState(new RunningCountMostRecentStatefulProcessor(), + SQLConf.STREAMING_VALUE_STATE_SCHEMA_EVOLUTION_THRESHOLD.key -> "0") { + withTempDir { chkptDir => + val dirPath = chkptDir.getCanonicalPath + val inputData = MemoryStream[String] + val result1 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorInt(), TimeMode.None(), OutputMode.Update()) - testStream(stream1, OutputMode.Update())( - StartStream(checkpointLocation = dirPath.getCanonicalPath), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "1", "")), - StopStream + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = dirPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + Execute { q => + assert(q.lastProgress.stateOperators(0).customMetrics.get("numValueStateVars") > 0) + assert(q.lastProgress.stateOperators(0).customMetrics.get("numRegisteredTimers") == 0) + assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1) + }, + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "2"), ("b", "1")), + StopStream, + StartStream(checkpointLocation = dirPath), + AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a + CheckNewAnswer(("b", "2")), + StopStream, + Execute { q => + assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1) + assert(q.lastProgress.stateOperators(0).numRowsRemoved === 1) + }, + StartStream(checkpointLocation = dirPath), + AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and + CheckNewAnswer(("a", "1"), ("c", "1")) ) - // Second run deletes count state but keeps mostRecent - val stream2 = inputData.toDS() - .groupByKey(x => x._1) - .transformWithState(new MostRecentStatefulProcessorWithDeletion(), + val result2 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), TimeMode.None(), OutputMode.Update()) - testStream(stream2, OutputMode.Update())( - StartStream(checkpointLocation = dirPath.getCanonicalPath), - AddData(inputData, ("a", "str2"), ("b", "str3")), - ExpectFailure[StateStoreStateSchemaFilesThresholdExceeded] { t => + testStream(result2, OutputMode.Update())( + StartStream(checkpointLocation = dirPath), + AddData(inputData, "a"), + ExpectFailure[StateStoreValueSchemaEvolutionThresholdExceeded] { t => checkError( - t.asInstanceOf[StateStoreStateSchemaFilesThresholdExceeded], - 
condition = "STATE_STORE_STATE_SCHEMA_FILES_THRESHOLD_EXCEEDED", + t.asInstanceOf[StateStoreValueSchemaEvolutionThresholdExceeded], + condition = "STATE_STORE_VALUE_SCHEMA_EVOLUTION_THRESHOLD_EXCEEDED", parameters = Map( - "numStateSchemaFiles" -> "2", - "maxStateSchemaFiles" -> "1", - "removedColumnFamilies" -> "(countState)", - "addedColumnFamilies" -> "()" + "numSchemaEvolutions" -> "1", + "maxSchemaEvolutions" -> "0", + "colFamilyName" -> "countState" ) ) } @@ -2368,576 +2429,498 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - test("test query restart succeeds") { + test("transformWithState - upcasting should succeed") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - withTempDir { checkpointDir => + withTempDir { chkptDir => + val dirPath = chkptDir.getCanonicalPath val inputData = MemoryStream[String] val result1 = inputData.toDS() .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessor(), + .transformWithState(new RunningCountStatefulProcessorInt(), TimeMode.None(), OutputMode.Update()) testStream(result1, OutputMode.Update())( - StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + StartStream(checkpointLocation = dirPath), AddData(inputData, "a"), CheckNewAnswer(("a", "1")), - StopStream + Execute { q => + assert(q.lastProgress.stateOperators(0).customMetrics.get("numValueStateVars") > 0) + assert(q.lastProgress.stateOperators(0).customMetrics.get("numRegisteredTimers") == 0) + assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1) + }, + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "2"), ("b", "1")), + StopStream, + StartStream(checkpointLocation = dirPath), + AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a + CheckNewAnswer(("b", "2")), + StopStream, + Execute { q => + assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1) + assert(q.lastProgress.stateOperators(0).numRowsRemoved === 1) + }, + StartStream(checkpointLocation = dirPath), + AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and + CheckNewAnswer(("a", "1"), ("c", "1")) ) + val result2 = inputData.toDS() .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessorWithProcTimeTimer(), + .transformWithState(new RunningCountStatefulProcessor(), TimeMode.None(), OutputMode.Update()) testStream(result2, OutputMode.Update())( - StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + StartStream(checkpointLocation = dirPath), AddData(inputData, "a"), CheckNewAnswer(("a", "2")), + AddData(inputData, "d"), + CheckNewAnswer(("d", "1")), StopStream ) } } } - test("SPARK-49070: transformWithState - valid initial state plan") { - withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName) { - withTempDir { srcDir => - Seq("a", "b", "c").foreach(createFile(_, srcDir)) - val df = createFileStream(srcDir) - - var index = 0 - - val q = df.writeStream - .foreachBatch((_: Dataset[(String, String)], _: Long) => { - index += 1 - }) - .trigger(Trigger.AvailableNow) - .start() - - try { - assert(q.awaitTermination(streamingTimeout.toMillis)) + test("transformWithState - reordering fields should succeed") { + withSQLConf( + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + 
TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { chkptDir => + val dirPath = chkptDir.getCanonicalPath + val inputData = MemoryStream[String] - val sparkPlan = - q.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.executedPlan - val transformWithStateExec = sparkPlan.collect { - case p: TransformWithStateExec => p - }.head + // First run with initial field order + val result1 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorInitialOrder(), + TimeMode.None(), + OutputMode.Update()) - assert(!transformWithStateExec.hasInitialState) + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = dirPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + StopStream + ) - // EnsureRequirements should not apply on the initial state plan - val exchange = transformWithStateExec.initialState.collect { - case s: ShuffleExchangeExec => s - } + // Second run with reordered fields + val result2 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorReorderedFields(), + TimeMode.None(), + OutputMode.Update()) - assert(exchange.isEmpty) - } finally { - q.stop() - } + testStream(result2, OutputMode.Update())( + StartStream(checkpointLocation = dirPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "2")), // Should continue counting from previous state + StopStream + ) } } } - private def getFiles(path: Path): Array[FileStatus] = { - val hadoopConf = spark.sessionState.newHadoopConf() - val fileManager = CheckpointFileManager.create(path, hadoopConf) - fileManager.list(path) - } - - private def getStateSchemaPath(stateCheckpointPath: Path): Path = { - new Path(stateCheckpointPath, "_stateSchema/default/") - } - - // TODO: [SPARK-50845] Re-enable tests after StateSchemaV3 threshold change - ignore("transformWithState - verify that metadata and schema logs are purged") { + test("transformWithState - adding field should succeed") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString, - SQLConf.MIN_BATCHES_TO_RETAIN.key -> "1") { + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { withTempDir { chkptDir => - val stateOpIdPath = new Path(new Path(chkptDir.getCanonicalPath, "state"), "0") - val stateSchemaPath = getStateSchemaPath(stateOpIdPath) - - val metadataPath = OperatorStateMetadataV2.metadataDirPath(stateOpIdPath) - // in this test case, we are changing the state spec back and forth - // to trigger the writing of the schema and metadata files - val inputData = MemoryStream[(String, String)] + val dirPath = chkptDir.getCanonicalPath + val inputData = MemoryStream[String] val result1 = inputData.toDS() - .groupByKey(x => x._1) - .transformWithState(new RunningCountMostRecentStatefulProcessor(), + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), TimeMode.None(), OutputMode.Update()) + testStream(result1, OutputMode.Update())( - StartStream(checkpointLocation = chkptDir.getCanonicalPath), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "1", "")), + StartStream(checkpointLocation = dirPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + Execute { q => + assert(q.lastProgress.stateOperators(0).customMetrics.get("numValueStateVars") > 0) + assert(q.lastProgress.stateOperators(0).customMetrics.get("numRegisteredTimers") == 0) + 
assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1) + }, + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "2"), ("b", "1")), + StopStream, + StartStream(checkpointLocation = dirPath), + AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a + CheckNewAnswer(("b", "2")), + StopStream, Execute { q => - eventually(timeout(Span(5, Seconds))) { - q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) - } + assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1) + assert(q.lastProgress.stateOperators(0).numRowsRemoved === 1) }, - StopStream + StartStream(checkpointLocation = dirPath), + AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and + CheckNewAnswer(("a", "1"), ("c", "1")) ) + val result2 = inputData.toDS() - .groupByKey(x => x._1) - .transformWithState(new MostRecentStatefulProcessorWithDeletion(), + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorNestedLongs(), TimeMode.None(), OutputMode.Update()) + testStream(result2, OutputMode.Update())( - StartStream(checkpointLocation = chkptDir.getCanonicalPath), - AddData(inputData, ("a", "str2")), - CheckNewAnswer(("a", "str1")), - Execute { q => - eventually(timeout(Span(5, Seconds))) { - q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) - } - }, + StartStream(checkpointLocation = dirPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "2")), StopStream ) - // assert that a metadata and schema file has been written for each run - // as state variables have been deleted - assert(getFiles(metadataPath).length == 2) - assert(getFiles(stateSchemaPath).length == 2) + } + } + } - val result3 = inputData.toDS() - .groupByKey(x => x._1) - .transformWithState(new RunningCountMostRecentStatefulProcessor(), + test("transformWithState - add and remove field between runs") { + withSQLConf( + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + withTempDir { dir => + val inputData = MemoryStream[String] + + // First run with original field names + val result1 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorInitialOrder(), TimeMode.None(), OutputMode.Update()) - testStream(result3, OutputMode.Update())( - StartStream(checkpointLocation = chkptDir.getCanonicalPath), - AddData(inputData, ("a", "str3")), - CheckNewAnswer(("a", "1", "str2")), - Execute { q => - eventually(timeout(Span(5, Seconds))) { - q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) - } - }, + + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = dir.getCanonicalPath), + AddData(inputData, "test1"), + CheckNewAnswer(("test1", "1")), StopStream ) - // because we don't change the schema for this run, there won't - // be a new schema file written. 
- testStream(result3, OutputMode.Update())( - StartStream(checkpointLocation = chkptDir.getCanonicalPath), - AddData(inputData, ("a", "str4")), - CheckNewAnswer(("a", "2", "str3")), - Execute { q => - eventually(timeout(Span(5, Seconds))) { - q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) - } - }, + + // Second run with renamed field + val result2 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RenameEvolvedProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result2, OutputMode.Update())( + StartStream(checkpointLocation = dir.getCanonicalPath), + // Uses default value, does not factor previous value1 into this + AddData(inputData, "test1"), + CheckNewAnswer(("test1", "1")), + // Verify we can write state with new field name + AddData(inputData, "test2"), + CheckNewAnswer(("test2", "1")), StopStream ) - // by the end of the test, there have been 4 batches, - // so the metadata and schema logs, and commitLog has been purged - // for batches 0 and 1 so metadata and schema files exist for batches 0, 1, 2, 3 - // and we only need to keep metadata files for batches 2, 3, and the since schema - // hasn't changed between batches 2, 3, we only keep the schema file for batch 2 - assert(getFiles(metadataPath).length == 2) - assert(getFiles(stateSchemaPath).length == 1) } } } - // TODO: [SPARK-50845] Re-enable tests after StateSchemaV3 threshold change - ignore("transformWithState - verify that schema file " + - "is kept after metadata is purged") { - withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString, - SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2") { + test("state data source - schema evolution with time travel support") { + withSQLConf( + rocksdbChangelogCheckpointingConfKey -> "true", + SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1", + SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1") { + withTempDir { chkptDir => - val stateOpIdPath = new Path(new Path(chkptDir.getCanonicalPath, "state"), "0") - val stateSchemaPath = getStateSchemaPath(stateOpIdPath) + val dirPath = chkptDir.getCanonicalPath + val inputData = MemoryStream[String] - val metadataPath = OperatorStateMetadataV2.metadataDirPath(stateOpIdPath) - // in this test case, we are changing the state spec back and forth - // to trigger the writing of the schema and metadata files - val inputData = MemoryStream[(String, String)] val result1 = inputData.toDS() - .groupByKey(x => x._1) - .transformWithState(new RunningCountMostRecentStatefulProcessor(), + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorTwoLongs(), TimeMode.None(), OutputMode.Update()) + testStream(result1, OutputMode.Update())( - StartStream(checkpointLocation = chkptDir.getCanonicalPath), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "1", "")), - Execute { q => - eventually(timeout(Span(5, Seconds))) { - q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) - } - }, - StopStream - ) - testStream(result1, OutputMode.Update())( - StartStream(checkpointLocation = chkptDir.getCanonicalPath), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "2", "str1")), - Execute { q => - eventually(timeout(Span(5, Seconds))) { - q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge 
should be(false) - } - }, + StartStream(checkpointLocation = dirPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + AddData(inputData, "b"), + CheckNewAnswer(("b", "1")), + ProcessAllAvailable(), + Execute { _ => Thread.sleep(5000) }, StopStream ) + val result2 = inputData.toDS() - .groupByKey(x => x._1) - .transformWithState(new MostRecentStatefulProcessorWithDeletion(), + .groupByKey(x => x) + .transformWithState(new RenameEvolvedProcessor(), TimeMode.None(), OutputMode.Update()) + testStream(result2, OutputMode.Update())( - StartStream(checkpointLocation = chkptDir.getCanonicalPath), - AddData(inputData, ("a", "str2")), - CheckNewAnswer(("a", "str1")), - Execute { q => - eventually(timeout(Span(5, Seconds))) { - q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) - } - }, + StartStream(checkpointLocation = dirPath), + AddData(inputData, "c"), + CheckNewAnswer(("c", "1")), + AddData(inputData, "d"), + CheckNewAnswer(("d", "1")), + ProcessAllAvailable(), + Execute { _ => Thread.sleep(5000) }, StopStream ) - assert(getFiles(metadataPath).length == 3) - assert(getFiles(stateSchemaPath).length == 2) - val result3 = inputData.toDS() - .groupByKey(x => x._1) - .transformWithState(new RunningCountMostRecentStatefulProcessor(), + val oldStateDf = spark.read + .format("statestore") + .option("snapshotStartBatchId", 0) + .option("batchId", 1) + .option("snapshotPartitionId", 0) + .option(StateSourceOptions.STATE_VAR_NAME, "countState") + .load(dirPath) + + checkAnswer( + oldStateDf.selectExpr( + "key.value AS groupingKey", + "value.value1 AS count"), + Seq(Row("a", 1), Row("b", 1)) + ) + + val evolvedStateDf1 = spark.read + .format("statestore") + .option("snapshotStartBatchId", 0) + .option("batchId", 3) + .option("snapshotPartitionId", 0) + .option(StateSourceOptions.STATE_VAR_NAME, "countState") + .load(dirPath) + + checkAnswer( + evolvedStateDf1.selectExpr( + "key.value AS groupingKey", + "value.value4 AS count"), + Seq( + Row("a", null), + Row("b", null), + Row("c", 1), + Row("d", 1) + ) + ) + + val evolvedStateDf = spark.read + .format("statestore") + .option("snapshotStartBatchId", 3) + .option("snapshotPartitionId", 0) + .option(StateSourceOptions.STATE_VAR_NAME, "countState") + .load(dirPath) + + checkAnswer( + evolvedStateDf.selectExpr( + "key.value AS groupingKey", + "value.value4 AS count"), + Seq( + Row("a", null), + Row("b", null), + Row("c", 1), + Row("d", 1) + ) + ) + } + } + } + + test("transformWithState - verify default values during schema evolution") { + withSQLConf( + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + withTempDir { dir => + val inputData = MemoryStream[String] + + // First run with basic schema + val result1 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new DefaultValueInitialProcessor(), TimeMode.None(), OutputMode.Update()) - testStream(result3, OutputMode.Update())( - StartStream(checkpointLocation = chkptDir.getCanonicalPath), - AddData(inputData, ("a", "str3")), - CheckNewAnswer(("a", "1", "str2")), - Execute { q => - eventually(timeout(Span(5, Seconds))) { - q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) - } - }, + + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = dir.getCanonicalPath), + AddData(inputData, "test1"), + CheckNewAnswer(("test1", BasicState("test1".hashCode, "test1"))), StopStream ) - // metadata files should be kept for batches 1, 2, 3 - // schema files 
should be kept for batches 0, 2, 3 - assert(getFiles(metadataPath).length == 3) - assert(getFiles(stateSchemaPath).length == 3) - // we want to ensure that we can read batch 1 even though the - // metadata file for batch 0 was removed - val batch1Df = spark.read - .format("statestore") - .option(StateSourceOptions.PATH, chkptDir.getAbsolutePath) - .option(StateSourceOptions.STATE_VAR_NAME, "countState") - .option(StateSourceOptions.BATCH_ID, 1) - .load() - val batch1AnsDf = batch1Df.selectExpr( - "key.value AS groupingKey", - "value.value AS valueId") + // Second run with evolved schema to check defaults + val result2 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new DefaultValueEvolvedProcessor(), + TimeMode.None(), + OutputMode.Update()) - checkAnswer(batch1AnsDf, Seq(Row("a", 2L))) + testStream(result2, OutputMode.Update())( + StartStream(checkpointLocation = dir.getCanonicalPath), - val batch3Df = spark.read - .format("statestore") - .option(StateSourceOptions.PATH, chkptDir.getAbsolutePath) - .option(StateSourceOptions.STATE_VAR_NAME, "countState") - .option(StateSourceOptions.BATCH_ID, 3) - .load() + // Check existing state - new fields should get default values + AddData(inputData, "test1"), + CheckNewAnswer( + ("test1", EvolvedState( + id = "test1".hashCode, + name = "test1", + count = 0L, + active = false, + score = 0.0 + )) + ), - val batch3AnsDf = batch3Df.selectExpr( - "key.value AS groupingKey", - "value.value AS valueId") - checkAnswer(batch3AnsDf, Seq(Row("a", 1L))) + // New state should get initialized values, not defaults + AddData(inputData, "test2"), + CheckNewAnswer( + ("test2", EvolvedState( + id = "test2".hashCode, + name = "test2", + count = 100L, + active = true, + score = 99.9 + )) + ), + StopStream + ) } } } - test("state data source integration - value state supports time travel") { + test("transformWithState - removing field should succeed") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString, - SQLConf.MIN_BATCHES_TO_RETAIN.key -> "5") { + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { withTempDir { chkptDir => - // in this test case, we are changing the state spec back and forth - // to trigger the writing of the schema and metadata files - val inputData = MemoryStream[(String, String)] - val result1 = inputData.toDS() - .groupByKey(x => x._1) - .transformWithState(new RunningCountMostRecentStatefulProcessor(), - TimeMode.None(), - OutputMode.Update()) - testStream(result1, OutputMode.Update())( - StartStream(checkpointLocation = chkptDir.getCanonicalPath), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "1", "")), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "2", "str1")), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "3", "str1")), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "4", "str1")), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "5", "str1")), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "6", "str1")), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "7", "str1")), - Execute { q => - eventually(timeout(Span(5, Seconds))) { - q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) - } - }, - StopStream - ) + val dirPath = chkptDir.getCanonicalPath + val inputData = MemoryStream[String] + val result2 = inputData.toDS() - .groupByKey(x => x._1) - 
.transformWithState(new MostRecentStatefulProcessorWithDeletion(), + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorTwoLongs(), TimeMode.None(), OutputMode.Update()) + testStream(result2, OutputMode.Update())( - StartStream(checkpointLocation = chkptDir.getCanonicalPath), - AddData(inputData, ("a", "str2")), - CheckNewAnswer(("a", "str1")), - AddData(inputData, ("a", "str3")), - CheckNewAnswer(("a", "str2")), - AddData(inputData, ("a", "str4")), - CheckNewAnswer(("a", "str3")), - Execute { q => - eventually(timeout(Span(5, Seconds))) { - q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) - } - }, + StartStream(checkpointLocation = dirPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), StopStream ) - // Batches 0-7: countState, mostRecent - // Batches 8-9: countState - - // By this time, offset and commit logs for batches 0-3 have been purged. - // However, if we want to read the data for batch 4, we need to read the corresponding - // metadata and schema file for batch 4, or the latest files that correspond to - // batch 4 (in this case, the files were written for batch 0). - // We want to test the behavior where the metadata files are preserved so that we can - // read from the state data source, even if the commit and offset logs are purged for - // a particular batch - - val df = spark.read.format("state-metadata").load(chkptDir.toString) + val result1 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) - // check the min and max batch ids that we have data for - checkAnswer( - df.select( - "operatorId", "operatorName", "stateStoreName", "numPartitions", "minBatchId", - "maxBatchId"), - Seq(Row(0, "transformWithStateExec", "default", 5, 4, 9)) + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = dirPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + StopStream ) - - val countStateDf = spark.read - .format("statestore") - .option(StateSourceOptions.PATH, chkptDir.getAbsolutePath) - .option(StateSourceOptions.STATE_VAR_NAME, "countState") - .option(StateSourceOptions.BATCH_ID, 4) - .load() - - val countStateAnsDf = countStateDf.selectExpr( - "key.value AS groupingKey", - "value.value AS valueId") - checkAnswer(countStateAnsDf, Seq(Row("a", 5L))) - - val mostRecentDf = spark.read - .format("statestore") - .option(StateSourceOptions.PATH, chkptDir.getAbsolutePath) - .option(StateSourceOptions.STATE_VAR_NAME, "mostRecent") - .option(StateSourceOptions.BATCH_ID, 4) - .load() - - val mostRecentAnsDf = mostRecentDf.selectExpr( - "key.value AS groupingKey", - "value.value") - checkAnswer(mostRecentAnsDf, Seq(Row("a", "str1"))) } } } - // TODO: [SPARK-50845] Re-enable tests after StateSchemaV3 threshold change - ignore("transformWithState - verify that all metadata and schema logs are not purged") { + test("test that invalid schema evolution " + + "fails query for column family") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString, - SQLConf.MIN_BATCHES_TO_RETAIN.key -> "3") { - withTempDir { chkptDir => - val inputData = MemoryStream[(String, String)] + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { checkpointDir => + val inputData = MemoryStream[String] val result1 = inputData.toDS() - .groupByKey(x => x._1) - 
.transformWithState(new RunningCountMostRecentStatefulProcessor(), + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), TimeMode.None(), OutputMode.Update()) + testStream(result1, OutputMode.Update())( - StartStream(checkpointLocation = chkptDir.getCanonicalPath), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "1", "")), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "2", "str1")), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "3", "str1")), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "4", "str1")), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "5", "str1")), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "6", "str1")), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "7", "str1")), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "8", "str1")), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "9", "str1")), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "10", "str1")), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "11", "str1")), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "12", "str1")), - Execute { q => - eventually(timeout(Span(5, Seconds))) { - q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) - } - }, - StopStream - ) - testStream(result1, OutputMode.Update())( - StartStream(checkpointLocation = chkptDir.getCanonicalPath), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "13", "str1")), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "14", "str1")), - Execute { q => - eventually(timeout(Span(5, Seconds))) { - q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) - } - }, + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), StopStream ) + val result2 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorInt(), + TimeMode.None(), + OutputMode.Update()) - val stateOpIdPath = new Path(new Path(chkptDir.getCanonicalPath, "state"), "0") - val stateSchemaPath = getStateSchemaPath(stateOpIdPath) - - val metadataPath = OperatorStateMetadataV2.metadataDirPath(stateOpIdPath) - - // Metadata files exist for batches 0, 12, and the thresholdBatchId is 8 - // as this is the earliest batchId for which the commit log is not present, - // so we need to keep metadata files for batch 0 so we can read the commit - // log correspondingly - assert(getFiles(metadataPath).length == 2) - assert(getFiles(stateSchemaPath).length == 1) + testStream(result2, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + ExpectFailure[StateStoreInvalidValueSchemaEvolution] { e => + checkError( + e.asInstanceOf[SparkUnsupportedOperationException], + condition = "STATE_STORE_INVALID_VALUE_SCHEMA_EVOLUTION", + parameters = Map( + "oldValueSchema" -> "StructType(StructField(value,LongType,true))", + "newValueSchema" -> "StructType(StructField(value,IntegerType,true))") + ) + } + ) } } } - // TODO: [SPARK-50845] Re-enable tests after StateSchemaV3 threshold change - ignore("transformWithState - verify that no metadata and schema logs are purged after" + - " removing column family") { + test("test that introducing TTL after restart fails query") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.SHUFFLE_PARTITIONS.key -> - 
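// --- Illustrative aside (not part of the patch) ---------------------------------
// The failure asserted above comes down to an Avro resolution rule: a reader may
// widen int to long, but it may not narrow long to int. A standalone check with
// Avro's own compatibility helper (record and field names invented for the example):
import org.apache.avro.{Schema, SchemaCompatibility}
import org.apache.avro.SchemaCompatibility.SchemaCompatibilityType

val longValue = new Schema.Parser().parse(
  """{"type": "record", "name": "Value", "fields": [{"name": "value", "type": "long"}]}""")
val intValue = new Schema.Parser().parse(
  """{"type": "record", "name": "Value", "fields": [{"name": "value", "type": "int"}]}""")

// Reading int data with a long reader is allowed (promotion) ...
assert(SchemaCompatibility.checkReaderWriterCompatibility(longValue, intValue).getType ==
  SchemaCompatibilityType.COMPATIBLE)
// ... but reading long data with an int reader is rejected, hence the query failure.
assert(SchemaCompatibility.checkReaderWriterCompatibility(intValue, longValue).getType ==
  SchemaCompatibilityType.INCOMPATIBLE)
// ---------------------------------------------------------------------------------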
TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString, - SQLConf.MIN_BATCHES_TO_RETAIN.key -> "3") { - withTempDir { chkptDir => - val inputData = MemoryStream[(String, String)] - val result1 = inputData.toDS() - .groupByKey(x => x._1) - .transformWithState(new RunningCountMostRecentStatefulProcessor(), - TimeMode.None(), + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { checkpointDir => + val inputData = MemoryStream[String] + val clock = new StreamManualClock + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.ProcessingTime(), OutputMode.Update()) - testStream(result1, OutputMode.Update())( - StartStream(checkpointLocation = chkptDir.getCanonicalPath), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "1", "")), - AddData(inputData, ("a", "str1")), - CheckNewAnswer(("a", "2", "str1")), - Execute { q => - eventually(timeout(Span(5, Seconds))) { - q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) - } - }, - StopStream - ) - testStream(result1, OutputMode.Update())( - StartStream(checkpointLocation = chkptDir.getCanonicalPath), - AddData(inputData, ("b", "str1")), - CheckNewAnswer(("b", "1", "")), - AddData(inputData, ("b", "str1")), - CheckNewAnswer(("b", "2", "str1")), - AddData(inputData, ("b", "str1")), - CheckNewAnswer(("b", "3", "str1")), - AddData(inputData, ("b", "str1")), - CheckNewAnswer(("b", "4", "str1")), - AddData(inputData, ("b", "str1")), - CheckNewAnswer(("b", "5", "str1")), - AddData(inputData, ("b", "str1")), - CheckNewAnswer(("b", "6", "str1")), - AddData(inputData, ("b", "str1")), - CheckNewAnswer(("b", "7", "str1")), - AddData(inputData, ("b", "str1")), - CheckNewAnswer(("b", "8", "str1")), - AddData(inputData, ("b", "str1")), - CheckNewAnswer(("b", "9", "str1")), - AddData(inputData, ("b", "str1")), - CheckNewAnswer(("b", "10", "str1")), - AddData(inputData, ("b", "str1")), - CheckNewAnswer(("b", "11", "str1")), - AddData(inputData, ("b", "str1")), - CheckNewAnswer(("b", "12", "str1")), - Execute { q => - eventually(timeout(Span(5, Seconds))) { - q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) - } - }, + + testStream(result, OutputMode.Update())( + StartStream( + trigger = Trigger.ProcessingTime("1 second"), + checkpointLocation = checkpointDir.getCanonicalPath, + triggerClock = clock), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "1")), + AdvanceManualClock(1 * 1000), StopStream ) val result2 = inputData.toDS() - .groupByKey(x => x._1) - .transformWithState(new MostRecentStatefulProcessorWithDeletion(), - TimeMode.None(), + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorWithTTL(), + TimeMode.ProcessingTime(), OutputMode.Update()) - testStream(result2, OutputMode.Update())( - StartStream(checkpointLocation = chkptDir.getCanonicalPath), - AddData(inputData, ("b", "str2")), - CheckNewAnswer(("b", "str1")), - AddData(inputData, ("b", "str3")), - CheckNewAnswer(("b", "str2")), - Execute { q => - eventually(timeout(Span(5, Seconds))) { - q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) - } - }, - StopStream + StartStream( + trigger = Trigger.ProcessingTime("1 second"), + checkpointLocation = checkpointDir.getCanonicalPath, + triggerClock = clock), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + ExpectFailure[StateStoreInvalidValueSchemaEvolution] { e => + checkError( + 
e.asInstanceOf[SparkUnsupportedOperationException], + condition = "STATE_STORE_INVALID_VALUE_SCHEMA_EVOLUTION", + parameters = Map( + "newValueSchema" -> ("StructType(StructField(value,StructType(StructField(" + + "value,LongType,true)),true),StructField(ttlExpirationMs,LongType,true))"), + "oldValueSchema" -> "StructType(StructField(value,LongType,true))") + ) + } ) - - val stateOpIdPath = new Path(new Path(chkptDir.getCanonicalPath, "state"), "0") - val stateSchemaPath = getStateSchemaPath(stateOpIdPath) - - val metadataPath = OperatorStateMetadataV2.metadataDirPath(stateOpIdPath) - - // Metadata files are written for batches 0, 2, and 14. - // Schema files are written for 0, 14 - // At the beginning of the last query run, the thresholdBatchId is 11. - // However, we would need both schema files to be preserved, if we want to - // be able to read from batch 11 onwards. - assert(getFiles(metadataPath).length == 2) - assert(getFiles(stateSchemaPath).length == 2) } } } - testWithEncoding("avro")("transformWithState - incompatible schema evolution should fail") { + test("transformWithState - incompatible schema evolution should fail") { withSQLConf( SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> "avro", @@ -2998,62 +2981,109 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } -class TransformWithStateValidationSuite extends StateStoreMetricsTest { - import testImplicits._ +class TransformWithStateUnsafeRowEncodingSuite extends TransformWithStateTest { - test("transformWithState - streaming with hdfsStateStoreProvider should fail") { - val inputData = MemoryStream[String] - val result = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new RunningCountStatefulProcessor(), - TimeMode.None(), - OutputMode.Update()) + import testImplicits._ - testStream(result, OutputMode.Update())( - AddData(inputData, "a"), - ExpectFailure[StateStoreMultipleColumnFamiliesNotSupportedException] { t => - assert(t.getMessage.contains("not supported")) + override protected def test(testName: String, testTags: Tag*)(testBody: => Any) + (implicit pos: Position): Unit = { + super.test(s"$testName (encoding = UnsafeRow)", testTags: _*) { + withSQLConf(SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> "unsaferow") { + testBody } - ) + } } - test("transformWithStateWithInitialState - streaming with hdfsStateStoreProvider should fail") { - val inputData = MemoryStream[InitInputRow] - val initDf = Seq(("init_1", 40.0), ("init_2", 100.0)).toDS() - .groupByKey(x => x._1) - .mapValues(x => x) - val result = inputData.toDS() - .groupByKey(x => x.key) - .transformWithState(new AccumulateStatefulProcessorWithInitState(), - TimeMode.None(), OutputMode.Append(), initDf - ) - testStream(result, OutputMode.Update())( - AddData(inputData, InitInputRow("a", "add", -1.0)), - ExpectFailure[StateStoreMultipleColumnFamiliesNotSupportedException] { - (t: Throwable) => { - assert(t.getMessage.contains("not supported")) - } - } - ) - } + test("test that invalid schema evolution " + + "fails query for column family") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { checkpointDir => + val inputData = MemoryStream[String] + val result1 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + 
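// --- Illustrative aside (not part of the patch) ---------------------------------
// In the TTL-after-restart tests, enabling TTL wraps the stored value in an outer
// struct that also carries the expiration timestamp, which is why the restarted
// query reports an incompatible value schema. A rough sketch of the two shapes,
// reconstructed from the expected error parameters (nullability simplified):
import org.apache.spark.sql.types.{LongType, StructField, StructType}

val valueSchemaWithoutTtl = StructType(Seq(StructField("value", LongType)))
val valueSchemaWithTtl = StructType(Seq(
  StructField("value", StructType(Seq(StructField("value", LongType)))), // original value, now nested
  StructField("ttlExpirationMs", LongType)))                             // TTL bookkeeping column
assert(valueSchemaWithoutTtl != valueSchemaWithTtl)
// ---------------------------------------------------------------------------------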
OutputMode.Update()) - test("transformWithState - validate timeModes") { - // validation tests should pass for TimeMode.None - TransformWithStateVariableUtils.validateTimeMode(TimeMode.None(), None) - TransformWithStateVariableUtils.validateTimeMode(TimeMode.None(), Some(10L)) + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + StopStream + ) + val result2 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorInt(), + TimeMode.None(), + OutputMode.Update()) - // validation tests should fail for TimeMode.ProcessingTime and TimeMode.EventTime - // when time values are not provided - val ex = intercept[SparkException] { - TransformWithStateVariableUtils.validateTimeMode(TimeMode.ProcessingTime(), None) + testStream(result2, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a"), + ExpectFailure[StateStoreValueSchemaNotCompatible] { + (t: Throwable) => { + assert(t.getMessage.contains("Please check number and type of fields.")) + } + } + ) + } } - assert(ex.getMessage.contains("Failed to find time values")) - TransformWithStateVariableUtils.validateTimeMode(TimeMode.ProcessingTime(), Some(10L)) + } - val ex1 = intercept[SparkException] { - TransformWithStateVariableUtils.validateTimeMode(TimeMode.EventTime(), None) + test("test that introducing TTL after restart fails query") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { checkpointDir => + val inputData = MemoryStream[String] + val clock = new StreamManualClock + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.ProcessingTime(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream( + trigger = Trigger.ProcessingTime("1 second"), + checkpointLocation = checkpointDir.getCanonicalPath, + triggerClock = clock), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "1")), + AdvanceManualClock(1 * 1000), + StopStream + ) + val result2 = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorWithTTL(), + TimeMode.ProcessingTime(), + OutputMode.Update()) + testStream(result2, OutputMode.Update())( + StartStream( + trigger = Trigger.ProcessingTime("1 second"), + checkpointLocation = checkpointDir.getCanonicalPath, + triggerClock = clock), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + ExpectFailure[StateStoreValueSchemaNotCompatible] { t => + checkError( + t.asInstanceOf[SparkUnsupportedOperationException], + condition = "STATE_STORE_VALUE_SCHEMA_NOT_COMPATIBLE", + parameters = Map( + "storedValueSchema" -> "StructType(StructField(value,LongType,false))", + "newValueSchema" -> + ("StructType(StructField(value,StructType(StructField(value,LongType,false))," + + "true),StructField(ttlExpirationMs,LongType,true))") + ) + ) + } + ) + } } - assert(ex1.getMessage.contains("Failed to find time values")) - TransformWithStateVariableUtils.validateTimeMode(TimeMode.EventTime(), Some(10L)) } }
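// --- Illustrative aside (not part of the patch) ---------------------------------
// The UnsafeRow suite above pins every test to one encoding by overriding `test` and
// wrapping the body in withSQLConf. The same pattern works for any encoding; the
// class below is a hypothetical sketch (its name is not taken from this patch) showing
// how an Avro-pinned sibling would look.
class TransformWithStateAvroEncodingSketchSuite extends TransformWithStateTest {

  override protected def test(testName: String, testTags: Tag*)(testBody: => Any)
      (implicit pos: Position): Unit = {
    super.test(s"$testName (encoding = Avro)", testTags: _*) {
      withSQLConf(SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> "avro") {
        testBody
      }
    }
  }
}
// ---------------------------------------------------------------------------------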