Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-51065][SQL] Disallowing non-nullable schema when Avro encoding is used for TransformWithState #49751

Closed
wants to merge 24 commits into from
Closed
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -4669,6 +4669,13 @@
],
"sqlState" : "42K06"
},
"STATE_STORE_SCHEMA_MUST_BE_NULLABLE" : {
"message" : [
"If schema evolution is enabled, all the fields in the schema for column family <columnFamilyName> must be nullable.",
ericm-db marked this conversation as resolved.
Show resolved Hide resolved
"Please make the schema nullable. Current schema: <schema>"
],
"sqlState" : "XXKST"
},
"STATE_STORE_STATE_SCHEMA_FILES_THRESHOLD_EXCEEDED" : {
"message" : [
"The number of state schema files <numStateSchemaFiles> exceeds the maximum number of state schema files for this query: <maxStateSchemaFiles>.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1470,6 +1470,39 @@ def check_exception(error):
check_exception=check_exception,
)

def test_not_nullable_fails(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not having identical test in Scala as well? I don't see a new test verifying the error.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The thing is, there is no way for a user to specify this using Scala.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes and probably also no.

I agree moderate users may not ever try to get around this and will just stick with a case class or POJO. But "we" can imagine a way to get around it, in exactly the same way we supported PySpark:

override protected val stateEncoder: ExpressionEncoder[Any] =
    ExpressionEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]]

This is how we come up with state encoder for Python version of FMGWS. This is to serde with Row interface - my rough memory says it's not InternalRow but Row, so, most likely work with GenericRow, but we can try with both GenericRow and InternalRow.

I'm OK with deferring this as follow-up.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now I get that you are not actually able to test this, as we have to just accept a non-nullable column and change it to nullable. Again, I doubt this is just a bug though.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Not really a bug, we figured out.)

with self.sql_conf({"spark.sql.streaming.stateStore.encodingFormat": "avro"}):
with tempfile.TemporaryDirectory() as checkpoint_dir:
input_path = tempfile.mkdtemp()
self._prepare_test_resource1(input_path)

df = self._build_test_df(input_path)

def check_basic_state(batch_df, batch_id):
result = batch_df.collect()[0]
assert result.value["id"] == 0 # First ID from test data
assert result.value["name"] == "name-0"

def check_exception(error):
from pyspark.errors.exceptions.captured import StreamingQueryException

if not isinstance(error, StreamingQueryException):
return False

error_msg = str(error)
return (
"[STATE_STORE_SCHEMA_MUST_BE_NULLABLE]" in error_msg
and "column family state must be nullable" in error_msg
)

self._run_evolution_test(
BasicProcessorNotNullable(),
checkpoint_dir,
check_basic_state,
df,
check_exception=check_exception,
)


class SimpleStatefulProcessorWithInitialState(StatefulProcessor):
# this dict is the same as input initial state dataframe
Expand Down Expand Up @@ -1893,6 +1926,27 @@ def close(self) -> None:
pass


class BasicProcessorNotNullable(StatefulProcessor):
    # Value-state schema that deliberately marks both fields as NON-nullable
    # (nullable=False). Used to verify that the query fails with
    # STATE_STORE_SCHEMA_MUST_BE_NULLABLE when the Avro encoding format is used.
    state_schema = StructType(
        [StructField("id", IntegerType(), False), StructField("name", StringType(), False)]
    )

    def init(self, handle):
        # Register a single value state named "state" using the non-nullable schema;
        # this registration is what the nullable-schema validation should reject.
        self.state = handle.getValueState("state", self.state_schema)

    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
        # Drain the input iterator; the row contents are not used — the state
        # value is derived from the grouping key alone.
        for pdf in rows:
            pass
        id_val = int(key[0])
        name = f"name-{id_val}"
        self.state.update((id_val, name))
        # Emit one row per key: the key plus the (id, name) struct written to state.
        yield pd.DataFrame({"id": [key[0]], "value": [{"id": id_val, "name": name}]})

    def close(self) -> None:
        # No resources to release.
        pass


class AddFieldsProcessor(StatefulProcessor):
state_schema = StructType(
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ case class TransformWithStateInPandasExec(
override def getColFamilySchemas(
setNullableFields: Boolean
): Map[String, StateStoreColFamilySchema] = {
driverProcessorHandle.getColumnFamilySchemas(setNullableFields)
driverProcessorHandle.getColumnFamilySchemas(true)
ericm-db marked this conversation as resolved.
Show resolved Hide resolved
}

override def getStateVariableInfos(): Map[String, TransformWithStateVariableInfo] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ trait ListStateMetricsImpl {
// We keep track of the count of entries in the list in a separate column family
// to avoid scanning the entire list to get the count.
private val counterCFValueSchema: StructType =
StructType(Seq(StructField("count", LongType, nullable = false)))
StructType(Seq(StructField("count", LongType)))

private val counterCFProjection = UnsafeProjection.create(counterCFValueSchema)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ object StateStoreColumnFamilySchemaUtils {
// Byte type is converted to Int in Avro, which doesn't work for us as Avro
// uses zig-zag encoding as opposed to big-endian for Ints
Seq(
StructField(s"${field.name}_marker", BinaryType, nullable = false),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets say nullable=true explicitly ?

StructField(s"${field.name}_marker", BinaryType),
field.copy(name = s"${field.name}_value", BinaryType)
)
} else {
Expand Down Expand Up @@ -117,7 +117,7 @@ object StateStoreColumnFamilySchemaUtils {
getRowCounterCFName(stateName), keySchemaId = 0,
keyEncoder.schema,
valueSchemaId = 0,
StructType(Seq(StructField("count", LongType, nullable = false))),
StructType(Seq(StructField("count", LongType))),
Some(NoPrefixKeyStateEncoderSpec(keyEncoder.schema)))
schemas.put(counterSchema.colFamilyName, counterSchema)

Expand Down Expand Up @@ -149,7 +149,7 @@ object StateStoreColumnFamilySchemaUtils {
keySchemaId = 0,
keyEncoder.schema,
valueSchemaId = 0,
StructType(Seq(StructField("count", LongType, nullable = false))),
StructType(Seq(StructField("count", LongType))),
Some(NoPrefixKeyStateEncoderSpec(keyEncoder.schema)))
schemas.put(countSchema.colFamilyName, countSchema)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,17 +364,20 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
}

def getColumnFamilySchemas(
setNullableFields: Boolean
shouldCheckNullable: Boolean
): Map[String, StateStoreColFamilySchema] = {
val schemas = columnFamilySchemas.toMap
if (setNullableFields) {
schemas.map { case (colFamilyName, stateStoreColFamilySchema) =>
colFamilyName -> stateStoreColFamilySchema.copy(
valueSchema = stateStoreColFamilySchema.valueSchema.toNullable
)
schemas.map { case (colFamilyName, schema) =>
// assert that each field is nullable if schema evolution is enabled
schema.valueSchema.fields.foreach { field =>
if (!field.nullable && shouldCheckNullable) {
throw StateStoreErrors.stateStoreSchemaMustBeNullable(
schema.colFamilyName, schema.valueSchema.toString())
}
}
} else {
schemas
colFamilyName -> schema.copy(
ericm-db marked this conversation as resolved.
Show resolved Hide resolved
valueSchema = schema.valueSchema.toNullable
)
}
}

Expand Down Expand Up @@ -549,7 +552,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
elementKeySchema: StructType): StateStoreColFamilySchema = {
val countIndexName = s"$$count_$stateName"
val countValueSchema = StructType(Seq(
StructField("count", LongType, nullable = false)
StructField("count", LongType)
))

StateStoreColFamilySchema(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ abstract class OneToManyTTLState(
// Schema of the entry count index: elementKey -> count
private val COUNT_INDEX = "$count_" + stateName
private val COUNT_INDEX_VALUE_SCHEMA: StructType =
StructType(Seq(StructField("count", LongType, nullable = false)))
StructType(Seq(StructField("count", LongType)))
private val countIndexValueProjector = UnsafeProjection.create(COUNT_INDEX_VALUE_SCHEMA)

// Reused internal row that we use to create an UnsafeRow with the schema of
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ case class TransformWithStateExec(
Some(NoPrefixKeyStateEncoderSpec(keySchema)))

val columnFamilySchemas = getDriverProcessorHandle()
.getColumnFamilySchemas(setNullableFields) ++
.getColumnFamilySchemas(false) ++
ericm-db marked this conversation as resolved.
Show resolved Hide resolved
Map(StateStore.DEFAULT_COL_FAMILY_NAME -> defaultSchema)
closeProcessorHandle()
columnFamilySchemas
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@ object StateStoreErrors {
new StateStoreValueSchemaNotCompatible(storedValueSchema, newValueSchema)
}

  /**
   * Builds the error raised when state schema validation requires all value-schema
   * fields to be nullable (per the STATE_STORE_SCHEMA_MUST_BE_NULLABLE condition,
   * e.g. when schema evolution is enabled) but a non-nullable field was found.
   *
   * @param columnFamilyName name of the column family whose schema failed validation
   * @param schema string rendering of the offending value schema
   */
  def stateStoreSchemaMustBeNullable(
      columnFamilyName: String,
      schema: String): StateStoreSchemaMustBeNullable = {
    new StateStoreSchemaMustBeNullable(columnFamilyName, schema)
  }

def stateStoreInvalidValueSchemaEvolution(
oldValueSchema: String,
newValueSchema: String): StateStoreInvalidValueSchemaEvolution = {
Expand Down Expand Up @@ -346,6 +352,15 @@ class StateStoreValueSchemaNotCompatible(
"storedValueSchema" -> storedValueSchema,
"newValueSchema" -> newValueSchema))

/**
 * Thrown when a state store column family's value schema contains non-nullable
 * fields but the configuration requires nullable fields (error condition
 * STATE_STORE_SCHEMA_MUST_BE_NULLABLE, sqlState XXKST).
 *
 * @param columnFamilyName name of the column family with the non-nullable schema
 * @param schema string rendering of the rejected schema, surfaced in the message
 */
class StateStoreSchemaMustBeNullable(
    columnFamilyName: String,
    schema: String)
  extends SparkUnsupportedOperationException(
    errorClass = "STATE_STORE_SCHEMA_MUST_BE_NULLABLE",
    messageParameters = Map(
      "columnFamilyName" -> columnFamilyName,
      "schema" -> schema))

class StateStoreInvalidValueSchemaEvolution(
oldValueSchema: String,
newValueSchema: String)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1815,7 +1815,6 @@ class TransformWithStateSuite extends StateStoreMetricsTest
TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
withTempDir { checkpointDir =>
// When Avro is used, we want to set the StructFields to nullable
val shouldBeNullable = usingAvroEncoding()
ericm-db marked this conversation as resolved.
Show resolved Hide resolved
val metadataPathPostfix = "state/0/_stateSchema/default"
val stateSchemaPath = new Path(checkpointDir.toString,
s"$metadataPathPostfix")
Expand All @@ -1826,15 +1825,15 @@ class TransformWithStateSuite extends StateStoreMetricsTest
val schema0 = StateStoreColFamilySchema(
"countState", 0,
keySchema, 0,
new StructType().add("value", LongType, nullable = shouldBeNullable),
ericm-db marked this conversation as resolved.
Show resolved Hide resolved
new StructType().add("value", LongType, nullable = true),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
)
val schema1 = StateStoreColFamilySchema(
"listState", 0,
keySchema, 0,
new StructType()
.add("id", LongType, nullable = shouldBeNullable)
.add("id", LongType, nullable = true)
.add("name", StringType),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
Expand All @@ -1857,7 +1856,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
val schema3 = StateStoreColFamilySchema(
"$rowCounter_listState", 0,
keySchema, 0,
new StructType().add("count", LongType, nullable = shouldBeNullable),
new StructType().add("count", LongType, nullable = true),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
)
Expand Down Expand Up @@ -2209,9 +2208,9 @@ class TransformWithStateSuite extends StateStoreMetricsTest
t.asInstanceOf[SparkUnsupportedOperationException],
condition = "STATE_STORE_VALUE_SCHEMA_NOT_COMPATIBLE",
parameters = Map(
"storedValueSchema" -> "StructType(StructField(value,LongType,false))",
"storedValueSchema" -> "StructType(StructField(value,LongType,true))",
"newValueSchema" ->
("StructType(StructField(value,StructType(StructField(value,LongType,false))," +
("StructType(StructField(value,StructType(StructField(value,LongType,true))," +
"true),StructField(ttlExpirationMs,LongType,true))")
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,6 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
) {
withTempDir { checkpointDir =>
// When Avro is used, we want to set the StructFields to nullable
val shouldBeNullable = usingAvroEncoding()
val metadataPathPostfix = "state/0/_stateSchema/default"
val stateSchemaPath = new Path(checkpointDir.toString, s"$metadataPathPostfix")
val hadoopConf = spark.sessionState.newHadoopConf()
Expand Down Expand Up @@ -317,15 +316,15 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
val schema2 = StateStoreColFamilySchema(
"$count_listState", 0,
keySchema, 0,
new StructType().add("count", LongType, nullable = shouldBeNullable),
new StructType().add("count", LongType, nullable = true),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
)

val schema3 = StateStoreColFamilySchema(
"$rowCounter_listState", 0,
keySchema, 0,
new StructType().add("count", LongType, nullable = shouldBeNullable),
new StructType().add("count", LongType, nullable = true),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
)
Expand Down Expand Up @@ -409,7 +408,7 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
"valueStateTTL", 0,
keySchema, 0,
new StructType()
.add("value", new StructType().add("value", IntegerType, nullable = shouldBeNullable))
.add("value", new StructType().add("value", IntegerType, nullable = true))
.add("ttlExpirationMs", LongType),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
Expand All @@ -418,7 +417,7 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
val schema10 = StateStoreColFamilySchema(
"valueState", 0,
keySchema, 0,
new StructType().add("value", IntegerType, nullable = shouldBeNullable),
new StructType().add("value", IntegerType, nullable = true),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
)
Expand All @@ -428,7 +427,7 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
keySchema, 0,
new StructType()
.add("value", new StructType()
.add("id", LongType, nullable = shouldBeNullable)
.add("id", LongType, nullable = true)
.add("name", StringType))
.add("ttlExpirationMs", LongType),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
Expand Down