Skip to content

Commit 6794e15

Browse files
ericm-dbHeartSaVioR
authored andcommitted
[SPARK-51065][SQL] Disallowing non-nullable schema when Avro encoding is used for TransformWithState
### What changes were proposed in this pull request? Right now, effectively set all fields in a schema to nullable, regardless of what the user specifies. - However, when Avro encoding is used, we want to enforce nullability in order to enable the schema evolution cases we support. - Nullability can only be set by the user in Python, so when non-nullable fields are defined, we throw an error - In Scala, Encoders.product set fields to non-nullable by default (user cannot configure this), so we turn the fields to nullable ### Why are the changes needed? In order to keep parity with the user-specified schema with the actual schema that we use, and to enable the schema evolution use cases we want ### Does this PR introduce _any_ user-facing change? This error is thrown if the schema is defined as non-nullable ``` Traceback (most recent call last): File "/Users/eric.marnadi/spark/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py", line 1496, in test_not_nullable_fails self._run_evolution_test( File "/Users/eric.marnadi/spark/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py", line 1344, in _run_evolution_test q.processAllAvailable() File "/Users/eric.marnadi/spark/python/pyspark/sql/streaming/query.py", line 351, in processAllAvailable return self._jsq.processAllAvailable() File "/Users/eric.marnadi/spark/python/lib/py4j-0.10.9.9-src.zip/py4j/java_gateway.py", line 1362, in __call__ return_value = get_return_value( File "/Users/eric.marnadi/spark/python/pyspark/errors/exceptions/captured.py", line 258, in deco raise converted from None pyspark.errors.exceptions.captured.StreamingQueryException: [STREAM_FAILED] Query [id = 541c5df0-24e4-4702-b87a-c4edfb6a952c, runId = 4259c7b9-3846-4f73-9204-c3d71b07018c] terminated with exception: [STATE_STORE_SCHEMA_MUST_BE_NULLABLE] If schema evolution is enabled, all the fields in the schema for column family state must be nullable. Please set the 'spark.sql.streaming.stateStore.encodingFormat' to 'UnsafeRow' or make the schema nullable. Current schema: StructType(StructField(id,IntegerType,false),StructField(name,StringType,false)) SQLSTATE: XXKST SQLSTATE: XXKST === Streaming Query === Identifier: evolution_test [id = 541c5df0-24e4-4702-b87a-c4edfb6a952c, runId = 4259c7b9-3846-4f73-9204-c3d71b07018c] Current Committed Offsets: {} Current ``` ### How was this patch tested? Unit tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #49751 from ericm-db/disallow-non-nullable-schema. Authored-by: Eric Marnadi <[email protected]> Signed-off-by: Jungtaek Lim <[email protected]> (cherry picked from commit 301b666) Signed-off-by: Jungtaek Lim <[email protected]>
1 parent ffaab48 commit 6794e15

File tree

12 files changed

+128
-21
lines changed

12 files changed

+128
-21
lines changed

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5059,6 +5059,14 @@
50595059
],
50605060
"sqlState" : "42601"
50615061
},
5062+
"TRANSFORM_WITH_STATE_SCHEMA_MUST_BE_NULLABLE" : {
5063+
"message" : [
5064+
"If Avro encoding is enabled, all the fields in the schema for column family <columnFamilyName> must be nullable",
5065+
"when using the TransformWithState operator.",
5066+
"Please make the schema nullable. Current schema: <schema>"
5067+
],
5068+
"sqlState" : "XXKST"
5069+
},
50625070
"TRANSPOSE_EXCEED_ROW_LIMIT" : {
50635071
"message" : [
50645072
"Number of rows exceeds the allowed limit of <maxValues> for TRANSPOSE. If this was intended, set <config> to at least the current row count."

python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1470,6 +1470,39 @@ def check_exception(error):
14701470
check_exception=check_exception,
14711471
)
14721472

1473+
def test_not_nullable_fails(self):
1474+
with self.sql_conf({"spark.sql.streaming.stateStore.encodingFormat": "avro"}):
1475+
with tempfile.TemporaryDirectory() as checkpoint_dir:
1476+
input_path = tempfile.mkdtemp()
1477+
self._prepare_test_resource1(input_path)
1478+
1479+
df = self._build_test_df(input_path)
1480+
1481+
def check_basic_state(batch_df, batch_id):
1482+
result = batch_df.collect()[0]
1483+
assert result.value["id"] == 0 # First ID from test data
1484+
assert result.value["name"] == "name-0"
1485+
1486+
def check_exception(error):
1487+
from pyspark.errors.exceptions.captured import StreamingQueryException
1488+
1489+
if not isinstance(error, StreamingQueryException):
1490+
return False
1491+
1492+
error_msg = str(error)
1493+
return (
1494+
"[TRANSFORM_WITH_STATE_SCHEMA_MUST_BE_NULLABLE]" in error_msg
1495+
and "column family state must be nullable" in error_msg
1496+
)
1497+
1498+
self._run_evolution_test(
1499+
BasicProcessorNotNullable(),
1500+
checkpoint_dir,
1501+
check_basic_state,
1502+
df,
1503+
check_exception=check_exception,
1504+
)
1505+
14731506

14741507
class SimpleStatefulProcessorWithInitialState(StatefulProcessor):
14751508
# this dict is the same as input initial state dataframe
@@ -1893,6 +1926,27 @@ def close(self) -> None:
18931926
pass
18941927

18951928

1929+
class BasicProcessorNotNullable(StatefulProcessor):
1930+
# Schema definitions
1931+
state_schema = StructType(
1932+
[StructField("id", IntegerType(), False), StructField("name", StringType(), False)]
1933+
)
1934+
1935+
def init(self, handle):
1936+
self.state = handle.getValueState("state", self.state_schema)
1937+
1938+
def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
1939+
for pdf in rows:
1940+
pass
1941+
id_val = int(key[0])
1942+
name = f"name-{id_val}"
1943+
self.state.update((id_val, name))
1944+
yield pd.DataFrame({"id": [key[0]], "value": [{"id": id_val, "name": name}]})
1945+
1946+
def close(self) -> None:
1947+
pass
1948+
1949+
18961950
class AddFieldsProcessor(StatefulProcessor):
18971951
state_schema = StructType(
18981952
[

sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,14 @@ case class TransformWithStateInPandasExec(
115115
override def operatorStateMetadataVersion: Int = 2
116116

117117
override def getColFamilySchemas(
118-
setNullableFields: Boolean
118+
shouldBeNullable: Boolean
119119
): Map[String, StateStoreColFamilySchema] = {
120-
driverProcessorHandle.getColumnFamilySchemas(setNullableFields)
120+
// For Python, the user can explicitly set nullability on schema, so
121+
// we need to throw an error if the schema is nullable
122+
driverProcessorHandle.getColumnFamilySchemas(
123+
shouldCheckNullable = shouldBeNullable,
124+
shouldSetNullable = shouldBeNullable
125+
)
121126
}
122127

123128
override def getStateVariableInfos(): Map[String, TransformWithStateVariableInfo] = {

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateMetricsImpl.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ trait ListStateMetricsImpl {
3838
// We keep track of the count of entries in the list in a separate column family
3939
// to avoid scanning the entire list to get the count.
4040
private val counterCFValueSchema: StructType =
41-
StructType(Seq(StructField("count", LongType, nullable = false)))
41+
StructType(Seq(StructField("count", LongType, nullable = true)))
4242

4343
private val counterCFProjection = UnsafeProjection.create(counterCFValueSchema)
4444

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ object StateStoreColumnFamilySchemaUtils {
4747
// Byte type is converted to Int in Avro, which doesn't work for us as Avro
4848
// uses zig-zag encoding as opposed to big-endian for Ints
4949
Seq(
50-
StructField(s"${field.name}_marker", BinaryType, nullable = false),
50+
StructField(s"${field.name}_marker", BinaryType, nullable = true),
5151
field.copy(name = s"${field.name}_value", BinaryType)
5252
)
5353
} else {
@@ -117,7 +117,7 @@ object StateStoreColumnFamilySchemaUtils {
117117
getRowCounterCFName(stateName), keySchemaId = 0,
118118
keyEncoder.schema,
119119
valueSchemaId = 0,
120-
StructType(Seq(StructField("count", LongType, nullable = false))),
120+
StructType(Seq(StructField("count", LongType, nullable = true))),
121121
Some(NoPrefixKeyStateEncoderSpec(keyEncoder.schema)))
122122
schemas.put(counterSchema.colFamilyName, counterSchema)
123123

@@ -149,7 +149,7 @@ object StateStoreColumnFamilySchemaUtils {
149149
keySchemaId = 0,
150150
keyEncoder.schema,
151151
valueSchemaId = 0,
152-
StructType(Seq(StructField("count", LongType, nullable = false))),
152+
StructType(Seq(StructField("count", LongType, nullable = true))),
153153
Some(NoPrefixKeyStateEncoderSpec(keyEncoder.schema)))
154154
schemas.put(countSchema.colFamilyName, countSchema)
155155
}

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -363,18 +363,38 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
363363
addTimerColFamily()
364364
}
365365

366+
/**
367+
* This method returns all column family schemas, and checks and enforces nullability
368+
* if need be. The nullability check and set is only set to true when Avro is enabled.
369+
* @param shouldCheckNullable Whether we need to check the nullability. This is set to
370+
* true when using Python, as this is the only avenue through
371+
* which users can set nullability
372+
* @param shouldSetNullable Whether we need to set the fields as nullable. This is set to
373+
* true when using Scala, as primitive type encoders set the field
374+
* to non-nullable. Changing fields from non-nullable to nullable
375+
* does not break anything (and is required for Avro encoding), so
376+
* we can safely make this change.
377+
* @return column family schemas used by this stateful processor.
378+
*/
366379
def getColumnFamilySchemas(
367-
setNullableFields: Boolean
380+
shouldCheckNullable: Boolean,
381+
shouldSetNullable: Boolean
368382
): Map[String, StateStoreColFamilySchema] = {
369383
val schemas = columnFamilySchemas.toMap
370-
if (setNullableFields) {
371-
schemas.map { case (colFamilyName, stateStoreColFamilySchema) =>
372-
colFamilyName -> stateStoreColFamilySchema.copy(
373-
valueSchema = stateStoreColFamilySchema.valueSchema.toNullable
384+
schemas.map { case (colFamilyName, schema) =>
385+
schema.valueSchema.fields.foreach { field =>
386+
if (!field.nullable && shouldCheckNullable) {
387+
throw StateStoreErrors.twsSchemaMustBeNullable(
388+
schema.colFamilyName, schema.valueSchema.toString())
389+
}
390+
}
391+
if (shouldSetNullable) {
392+
colFamilyName -> schema.copy(
393+
valueSchema = schema.valueSchema.toNullable
374394
)
395+
} else {
396+
colFamilyName -> schema
375397
}
376-
} else {
377-
schemas
378398
}
379399
}
380400

@@ -549,7 +569,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
549569
elementKeySchema: StructType): StateStoreColFamilySchema = {
550570
val countIndexName = s"$$count_$stateName"
551571
val countValueSchema = StructType(Seq(
552-
StructField("count", LongType, nullable = false)
572+
StructField("count", LongType)
553573
))
554574

555575
StateStoreColFamilySchema(

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ abstract class OneToManyTTLState(
357357
// Schema of the entry count index: elementKey -> count
358358
private val COUNT_INDEX = "$count_" + stateName
359359
private val COUNT_INDEX_VALUE_SCHEMA: StructType =
360-
StructType(Seq(StructField("count", LongType, nullable = false)))
360+
StructType(Seq(StructField("count", LongType)))
361361
private val countIndexValueProjector = UnsafeProjection.create(COUNT_INDEX_VALUE_SCHEMA)
362362

363363
// Reused internal row that we use to create an UnsafeRow with the schema of

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ case class TransformWithStateExec(
140140
* after init is called.
141141
*/
142142
override def getColFamilySchemas(
143-
setNullableFields: Boolean
143+
shouldBeNullable: Boolean
144144
): Map[String, StateStoreColFamilySchema] = {
145145
val keySchema = keyExpressions.toStructType
146146
// we have to add the default column family schema because the RocksDBStateEncoder
@@ -149,8 +149,11 @@ case class TransformWithStateExec(
149149
0, keyExpressions.toStructType, 0, DUMMY_VALUE_ROW_SCHEMA,
150150
Some(NoPrefixKeyStateEncoderSpec(keySchema)))
151151

152+
// For Scala, the user can't explicitly set nullability on schema, so there is
153+
// no reason to throw an error, and we can simply set the schema to nullable.
152154
val columnFamilySchemas = getDriverProcessorHandle()
153-
.getColumnFamilySchemas(setNullableFields) ++
155+
.getColumnFamilySchemas(
156+
shouldCheckNullable = false, shouldSetNullable = shouldBeNullable) ++
154157
Map(StateStore.DEFAULT_COL_FAMILY_NAME -> defaultSchema)
155158
closeProcessorHandle()
156159
columnFamilySchemas

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,9 @@ object TransformWithStateOperatorProperties extends Logging {
175175
*/
176176
trait TransformWithStateMetadataUtils extends Logging {
177177

178-
def getColFamilySchemas(setNullableFields: Boolean): Map[String, StateStoreColFamilySchema]
178+
// This method will return the column family schemas, and check whether the fields in the
179+
// schema are nullable. If Avro encoding is used, we want to enforce nullability
180+
def getColFamilySchemas(shouldBeNullable: Boolean): Map[String, StateStoreColFamilySchema]
179181

180182
def getStateVariableInfos(): Map[String, TransformWithStateVariableInfo]
181183

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,12 @@ object StateStoreErrors {
145145
new StateStoreValueSchemaNotCompatible(storedValueSchema, newValueSchema)
146146
}
147147

148+
def twsSchemaMustBeNullable(
149+
columnFamilyName: String,
150+
schema: String): TWSSchemaMustBeNullable = {
151+
new TWSSchemaMustBeNullable(columnFamilyName, schema)
152+
}
153+
148154
def stateStoreInvalidValueSchemaEvolution(
149155
oldValueSchema: String,
150156
newValueSchema: String): StateStoreInvalidValueSchemaEvolution = {
@@ -346,6 +352,15 @@ class StateStoreValueSchemaNotCompatible(
346352
"storedValueSchema" -> storedValueSchema,
347353
"newValueSchema" -> newValueSchema))
348354

355+
class TWSSchemaMustBeNullable(
356+
columnFamilyName: String,
357+
schema: String)
358+
extends SparkUnsupportedOperationException(
359+
errorClass = "TRANSFORM_WITH_STATE_SCHEMA_MUST_BE_NULLABLE",
360+
messageParameters = Map(
361+
"columnFamilyName" -> columnFamilyName,
362+
"schema" -> schema))
363+
349364
class StateStoreInvalidValueSchemaEvolution(
350365
oldValueSchema: String,
351366
newValueSchema: String)

sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1406,7 +1406,7 @@ abstract class TransformWithStateSuite extends StateStoreMetricsTest
14061406
val schema3 = StateStoreColFamilySchema(
14071407
"$rowCounter_listState", 0,
14081408
keySchema, 0,
1409-
new StructType().add("count", LongType, nullable = shouldBeNullable),
1409+
new StructType().add("count", LongType, nullable = true),
14101410
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
14111411
None
14121412
)

sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,15 +317,15 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
317317
val schema2 = StateStoreColFamilySchema(
318318
"$count_listState", 0,
319319
keySchema, 0,
320-
new StructType().add("count", LongType, nullable = shouldBeNullable),
320+
new StructType().add("count", LongType, nullable = true),
321321
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
322322
None
323323
)
324324

325325
val schema3 = StateStoreColFamilySchema(
326326
"$rowCounter_listState", 0,
327327
keySchema, 0,
328-
new StructType().add("count", LongType, nullable = shouldBeNullable),
328+
new StructType().add("count", LongType, nullable = true),
329329
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
330330
None
331331
)

0 commit comments

Comments
 (0)