[SPARK-51065][SQL] Disallowing non-nullable schema when Avro encoding is used for TransformWithState #49751

Open · wants to merge 22 commits into base: master · Changes from 8 commits
7 changes: 7 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -4669,6 +4669,13 @@
],
"sqlState" : "42K06"
},
"STATE_STORE_SCHEMA_MUST_BE_NULLABLE" : {
"message" : [
"If schema evolution is enabled, all the fields in the schema for column family <columnFamilyName> must be nullable.",
"Please make the schema nullable. Current schema: <schema>"
],
"sqlState" : "XXKST"
},
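The condition this new error class guards can be sketched in plain Python (helper names here are hypothetical; the real check lives in `DriverStatefulProcessorHandleImpl.getColumnFamilySchemas` further down in this diff):

```python
def is_internal(col_family_name: str) -> bool:
    # Internal column families (e.g. "$rowCounter_listState") are exempt
    # from the nullability requirement.
    return col_family_name.startswith("_") or col_family_name.startswith("$")


def check_schema_nullable(col_family_name: str, fields: list) -> None:
    """fields is a list of (field_name, nullable) pairs."""
    for name, nullable in fields:
        if not nullable and not is_internal(col_family_name):
            raise ValueError(
                f"[STATE_STORE_SCHEMA_MUST_BE_NULLABLE] If schema evolution is "
                f"enabled, all the fields in the schema for column family "
                f"{col_family_name} must be nullable. Offending field: {name}"
            )
```

User-defined column families with any non-nullable field fail; internal ones (prefixed with `_` or `$`) are passed through.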
"STATE_STORE_STATE_SCHEMA_FILES_THRESHOLD_EXCEEDED" : {
"message" : [
"The number of state schema files <numStateSchemaFiles> exceeds the maximum number of state schema files for this query: <maxStateSchemaFiles>.",
@@ -1469,7 +1469,37 @@ def check_exception(error):
            df,
            check_exception=check_exception,
        )
    def test_not_nullable_fails(self):

Contributor: Why not have an identical test in Scala as well? I don't see a new test verifying the error.

Contributor Author: The thing is, there is no way for a user to specify this using Scala.

Contributor: Yes, and probably also no. I agree most users may never try to work around this and will just stick with a case class or a POJO. But "we" can imagine a way around it, in exactly the same way we were able to support PySpark:

override protected val stateEncoder: ExpressionEncoder[Any] =
    ExpressionEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]]

This is how we come up with the state encoder for the Python version of FMGWS. It does serde through the Row interface - my rough memory says it is Row rather than InternalRow, so it would most likely work with GenericRow, but we can try both GenericRow and InternalRow.

I'm OK with deferring this as a follow-up.

Contributor: Now I get that you are not actually able to test this, as we have to just accept a non-nullable column and change it to nullable. Again, I doubt this is just a bug though.

Contributor: (Not really a bug, we figured out.)
        with self.sql_conf({"spark.sql.streaming.stateStore.encodingFormat": "avro"}):
            with tempfile.TemporaryDirectory() as checkpoint_dir:
                input_path = tempfile.mkdtemp()
                self._prepare_test_resource1(input_path)

                df = self._build_test_df(input_path)

                def check_basic_state(batch_df, batch_id):
                    result = batch_df.collect()[0]
                    assert result.value["id"] == 0  # First ID from test data
                    assert result.value["name"] == "name-0"

                def check_exception(error):
                    from pyspark.errors.exceptions.captured import StreamingQueryException

                    if not isinstance(error, StreamingQueryException):
                        return False

                    error_msg = str(error)
                    return (
                        "[STATE_STORE_SCHEMA_MUST_BE_NULLABLE]" in error_msg
                        and "column family state must be nullable" in error_msg
                    )

                self._run_evolution_test(
                    BasicProcessorNotNullable(),
                    checkpoint_dir,
                    check_basic_state,
                    df,
                    check_exception=check_exception,
                )

class SimpleStatefulProcessorWithInitialState(StatefulProcessor):
# this dict is the same as input initial state dataframe
@@ -1892,6 +1922,26 @@ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
    def close(self) -> None:
        pass

class BasicProcessorNotNullable(StatefulProcessor):
    # Schema definitions
    state_schema = StructType(
        [StructField("id", IntegerType(), False), StructField("name", StringType(), False)]
    )

    def init(self, handle):
        self.state = handle.getValueState("state", self.state_schema)

    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
        for pdf in rows:
            pass
        id_val = int(key[0])
        name = f"name-{id_val}"
        self.state.update((id_val, name))
        yield pd.DataFrame({"id": [key[0]], "value": [{"id": id_val, "name": name}]})

    def close(self) -> None:
        pass
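For context on why the non-nullable schema above has to fail: under schema evolution, state rows written with an old schema are read back under a newer one, and any field the old writer did not know about comes back as null. A plain-Python sketch of that decode step (the real path goes through the Avro encoder in the state store; the names here are illustrative):

```python
# Rows written with the old two-field schema must still decode under an
# evolved schema that added a field; the missing field comes back as None,
# which is only legal if the field is nullable.

OLD_ROW = {"id": 0, "name": "name-0"}   # written before evolution
NEW_FIELDS = ["id", "name", "count"]    # hypothetical "count" field added later


def decode(row, fields):
    # Missing fields default to None -- the "null" branch of an Avro union.
    return {f: row.get(f) for f in fields}
```

If `count` were declared non-nullable, the decoded `None` would violate the schema, which is exactly the situation the new error class forbids up front.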


class AddFieldsProcessor(StatefulProcessor):
    state_schema = StructType(
@@ -117,7 +117,7 @@ case class TransformWithStateInPandasExec(
  override def getColFamilySchemas(
      setNullableFields: Boolean
  ): Map[String, StateStoreColFamilySchema] = {
    driverProcessorHandle.getColumnFamilySchemas(true)
  }

override def getStateVariableInfos(): Map[String, TransformWithStateVariableInfo] = {
@@ -363,18 +363,25 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
addTimerColFamily()
}

private def isInternal(columnFamilyName: String): Boolean = {
columnFamilyName.startsWith("_") || columnFamilyName.startsWith("$")
}

  def getColumnFamilySchemas(
      shouldCheckNullable: Boolean
  ): Map[String, StateStoreColFamilySchema] = {
    val schemas = columnFamilySchemas.toMap
    schemas.map { case (colFamilyName, schema) =>
      // assert that each field is nullable if schema evolution is enabled
      schema.valueSchema.fields.foreach { field =>
        if (!field.nullable && shouldCheckNullable && !isInternal(colFamilyName)) {
          throw StateStoreErrors.stateStoreSchemaMustBeNullable(
            schema.colFamilyName, schema.valueSchema.toString())
        }
      }
      colFamilyName -> schema.copy(
        valueSchema = schema.valueSchema.toNullable
      )
    }
  }
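A sketch of what the `valueSchema.toNullable` conversion above does, assuming it recursively flips every field, including fields of nested structs, to nullable. This uses a simplified `(name, type, nullable)` tuple representation rather than Spark's `StructType`:

```python
def to_nullable(fields):
    # fields: list of (name, dtype, nullable) tuples; a nested struct is
    # represented as a list of such tuples in the dtype position.
    out = []
    for name, dtype, nullable in fields:
        if isinstance(dtype, list):  # nested struct: flip its fields too
            dtype = to_nullable(dtype)
        out.append((name, dtype, True))
    return out
```

Running non-nullable fields through this is why, once the check passes, every stored schema ends up fully nullable regardless of what the processor declared.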

@@ -150,7 +150,7 @@ case class TransformWithStateExec(
Some(NoPrefixKeyStateEncoderSpec(keySchema)))

val columnFamilySchemas = getDriverProcessorHandle()
.getColumnFamilySchemas(false) ++
Map(StateStore.DEFAULT_COL_FAMILY_NAME -> defaultSchema)
closeProcessorHandle()
columnFamilySchemas
@@ -145,6 +145,12 @@ object StateStoreErrors {
new StateStoreValueSchemaNotCompatible(storedValueSchema, newValueSchema)
}

def stateStoreSchemaMustBeNullable(
columnFamilyName: String,
schema: String): StateStoreSchemaMustBeNullable = {
new StateStoreSchemaMustBeNullable(columnFamilyName, schema)
}

def stateStoreInvalidValueSchemaEvolution(
oldValueSchema: String,
newValueSchema: String): StateStoreInvalidValueSchemaEvolution = {
@@ -346,6 +352,15 @@ class StateStoreValueSchemaNotCompatible(
"storedValueSchema" -> storedValueSchema,
"newValueSchema" -> newValueSchema))

class StateStoreSchemaMustBeNullable(
columnFamilyName: String,
schema: String)
extends SparkUnsupportedOperationException(
errorClass = "STATE_STORE_SCHEMA_MUST_BE_NULLABLE",
messageParameters = Map(
"columnFamilyName" -> columnFamilyName,
"schema" -> schema))

class StateStoreInvalidValueSchemaEvolution(
oldValueSchema: String,
newValueSchema: String)
@@ -1815,7 +1815,6 @@ class TransformWithStateSuite extends StateStoreMetricsTest
TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
withTempDir { checkpointDir =>
// When Avro is used, we want to set the StructFields to nullable
val metadataPathPostfix = "state/0/_stateSchema/default"
val stateSchemaPath = new Path(checkpointDir.toString,
s"$metadataPathPostfix")
@@ -1826,15 +1825,15 @@ class TransformWithStateSuite extends StateStoreMetricsTest
val schema0 = StateStoreColFamilySchema(
"countState", 0,
keySchema, 0,
new StructType().add("value", LongType, nullable = true),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
)
val schema1 = StateStoreColFamilySchema(
"listState", 0,
keySchema, 0,
new StructType()
.add("id", LongType, nullable = true)
.add("name", StringType),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
@@ -1857,7 +1856,7 @@
val schema3 = StateStoreColFamilySchema(
"$rowCounter_listState", 0,
keySchema, 0,
new StructType().add("count", LongType, nullable = true),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
)
@@ -273,7 +273,6 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
) {
withTempDir { checkpointDir =>
// When Avro is used, we want to set the StructFields to nullable
val metadataPathPostfix = "state/0/_stateSchema/default"
val stateSchemaPath = new Path(checkpointDir.toString, s"$metadataPathPostfix")
val hadoopConf = spark.sessionState.newHadoopConf()
@@ -317,15 +316,15 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
val schema2 = StateStoreColFamilySchema(
"$count_listState", 0,
keySchema, 0,
new StructType().add("count", LongType, nullable = true),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
)

val schema3 = StateStoreColFamilySchema(
"$rowCounter_listState", 0,
keySchema, 0,
new StructType().add("count", LongType, nullable = true),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
)
@@ -409,7 +408,7 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
"valueStateTTL", 0,
keySchema, 0,
new StructType()
.add("value", new StructType().add("value", IntegerType, nullable = true))
.add("ttlExpirationMs", LongType),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
@@ -418,7 +417,7 @@
val schema10 = StateStoreColFamilySchema(
"valueState", 0,
keySchema, 0,
new StructType().add("value", IntegerType, nullable = true),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
)
@@ -428,7 +427,7 @@
keySchema, 0,
new StructType()
.add("value", new StructType()
.add("id", LongType, nullable = true)
.add("name", StringType))
.add("ttlExpirationMs", LongType),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),