
[SPARK-51065][SQL] Disallowing non-nullable schema when Avro encoding is used for TransformWithState #49751

Open · wants to merge 22 commits into base: master (showing changes from 19 commits)
8 changes: 8 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json

@@ -5072,6 +5072,14 @@
    ],
    "sqlState" : "42601"
  },
  "TRANSFORM_WITH_STATE_SCHEMA_MUST_BE_NULLABLE" : {
    "message" : [
      "If schema evolution is enabled, all the fields in the schema for column family <columnFamilyName> must be nullable",
[Contributor] Which do we think is easier to understand, "using Avro" or "schema evolution is enabled"?

I foresee us heading toward using Avro for all stateful operators (unless there is an outstanding regression), and once we make Avro the default, this message will be confusing to consume, because users won't have done anything with schema evolution. IMO it is "indirect" information, and they would probably try to figure out how to disable schema evolution instead, without knowing that Avro and schema evolution are coupled.

cc. @anishshri-db to hear his voice.

[Contributor] Yeah, I think it's fine to refer to the TransformWithState case relative to Avro being used - we don't need to explicitly call out schema evolution here.

" when using the TransformWithState operator.",
" Please make the schema nullable. Current schema: <schema>"
],
"sqlState" : "XXKST"
},
"TRANSPOSE_EXCEED_ROW_LIMIT" : {
"message" : [
"Number of rows exceeds the allowed limit of <maxValues> for TRANSPOSE. If this was intended, set <config> to at least the current row count."
@@ -1470,6 +1470,39 @@ def check_exception(error):
check_exception=check_exception,
)

def test_not_nullable_fails(self):
[Contributor] Why not have an identical test in Scala as well? I don't see a new test verifying the error.

[Contributor Author] The thing is, there is no way for the user to specify this using Scala.

[Contributor] Yes, and probably also no.

I agree moderate users may never try to get around it and will just stick with a case class or POJO. But "we" can imagine a way around it, exactly the same way we support PySpark:

    override protected val stateEncoder: ExpressionEncoder[Any] =
        ExpressionEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]]

This is how we come up with the state encoder for the Python version of FMGWS. It does serde through the Row interface - my rough memory says it's Row rather than InternalRow, so it would most likely work with GenericRow, but we can try both GenericRow and InternalRow.

I'm OK with deferring this as a follow-up.

[Contributor] Now I get that you are not able to test this actually, as we have to just accept a non-nullable column and change it to nullable. Again, I doubt this is just a bug though.

[Contributor] (Not really a bug, we figured out.)

    with self.sql_conf({"spark.sql.streaming.stateStore.encodingFormat": "avro"}):
        with tempfile.TemporaryDirectory() as checkpoint_dir:
            input_path = tempfile.mkdtemp()
            self._prepare_test_resource1(input_path)

            df = self._build_test_df(input_path)

            def check_basic_state(batch_df, batch_id):
                result = batch_df.collect()[0]
                assert result.value["id"] == 0  # First ID from test data
                assert result.value["name"] == "name-0"

            def check_exception(error):
                from pyspark.errors.exceptions.captured import StreamingQueryException

                if not isinstance(error, StreamingQueryException):
                    return False

                error_msg = str(error)
                return (
                    "[TRANSFORM_WITH_STATE_SCHEMA_MUST_BE_NULLABLE]" in error_msg
                    and "column family state must be nullable" in error_msg
                )

            self._run_evolution_test(
                BasicProcessorNotNullable(),
                checkpoint_dir,
                check_basic_state,
                df,
                check_exception=check_exception,
            )
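The `check_exception` callback above is just a predicate over the raised error's message. A small factory (a hypothetical helper, not part of the PR) captures the pattern:

```python
# Hypothetical helper illustrating the check_exception contract used by
# _run_evolution_test: return True only for the expected streaming failure.

def make_error_checker(error_class: str, fragment: str):
    def check_exception(error: Exception) -> bool:
        msg = str(error)
        # Spark error messages embed the error class in square brackets
        return f"[{error_class}]" in msg and fragment in msg
    return check_exception

check = make_error_checker(
    "TRANSFORM_WITH_STATE_SCHEMA_MUST_BE_NULLABLE", "must be nullable"
)
```

The real test additionally requires the error to be a `StreamingQueryException`; the sketch keeps only the message matching.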


class SimpleStatefulProcessorWithInitialState(StatefulProcessor):
# this dict is the same as input initial state dataframe
@@ -1893,6 +1926,27 @@ def close(self) -> None:
pass


class BasicProcessorNotNullable(StatefulProcessor):
    # Schema definitions
    state_schema = StructType(
        [StructField("id", IntegerType(), False), StructField("name", StringType(), False)]
    )

    def init(self, handle):
        self.state = handle.getValueState("state", self.state_schema)

    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
        for pdf in rows:
            pass
        id_val = int(key[0])
        name = f"name-{id_val}"
        self.state.update((id_val, name))
        yield pd.DataFrame({"id": [key[0]], "value": [{"id": id_val, "name": name}]})

    def close(self) -> None:
        pass


class AddFieldsProcessor(StatefulProcessor):
    state_schema = StructType(
        [

@@ -115,9 +115,14 @@ case class TransformWithStateInPandasExec(
override def operatorStateMetadataVersion: Int = 2

  override def getColFamilySchemas(
-      setNullableFields: Boolean
+      shouldBeNullable: Boolean
  ): Map[String, StateStoreColFamilySchema] = {
-    driverProcessorHandle.getColumnFamilySchemas(setNullableFields)
+    // For Python, the user can explicitly set nullability on the schema, so
+    // we need to throw an error if any field in the schema is non-nullable
+    driverProcessorHandle.getColumnFamilySchemas(
+      shouldCheckNullable = shouldBeNullable,
+      shouldSetNullable = shouldBeNullable
+    )
  }

override def getStateVariableInfos(): Map[String, TransformWithStateVariableInfo] = {
@@ -38,7 +38,7 @@ trait ListStateMetricsImpl {
// We keep track of the count of entries in the list in a separate column family
// to avoid scanning the entire list to get the count.
  private val counterCFValueSchema: StructType =
-    StructType(Seq(StructField("count", LongType, nullable = false)))
+    StructType(Seq(StructField("count", LongType, nullable = true)))

private val counterCFProjection = UnsafeProjection.create(counterCFValueSchema)
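The comment above explains why the count lives in its own column family. A toy Python sketch of the idea (illustrative names only, not Spark's API):

```python
# Toy model of the counter-column-family idea: maintain an entry count next
# to the list state so size queries never scan the list itself.

class CountedListState:
    def __init__(self):
        self._lists = {}   # grouping key -> list of values (the list column family)
        self._counts = {}  # grouping key -> entry count   (the "$count_..." column family)

    def append(self, key, value):
        self._lists.setdefault(key, []).append(value)
        self._counts[key] = self._counts.get(key, 0) + 1

    def size(self, key):
        # O(1) lookup instead of an O(n) scan over the list entries
        return self._counts.get(key, 0)
```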

@@ -47,7 +47,7 @@ object StateStoreColumnFamilySchemaUtils {
// Byte type is converted to Int in Avro, which doesn't work for us as Avro
// uses zig-zag encoding as opposed to big-endian for Ints
Seq(
-      StructField(s"${field.name}_marker", BinaryType, nullable = false),
[Contributor] Let's say nullable = true explicitly?

+      StructField(s"${field.name}_marker", BinaryType, nullable = true),
       field.copy(name = s"${field.name}_value", BinaryType)
     )
} else {
@@ -117,7 +117,7 @@ object StateStoreColumnFamilySchemaUtils {
getRowCounterCFName(stateName), keySchemaId = 0,
keyEncoder.schema,
valueSchemaId = 0,
-      StructType(Seq(StructField("count", LongType, nullable = false))),
+      StructType(Seq(StructField("count", LongType, nullable = true))),
Some(NoPrefixKeyStateEncoderSpec(keyEncoder.schema)))
schemas.put(counterSchema.colFamilyName, counterSchema)

@@ -149,7 +149,7 @@ object StateStoreColumnFamilySchemaUtils {
keySchemaId = 0,
keyEncoder.schema,
valueSchemaId = 0,
-      StructType(Seq(StructField("count", LongType, nullable = false))),
+      StructType(Seq(StructField("count", LongType, nullable = true))),
Some(NoPrefixKeyStateEncoderSpec(keyEncoder.schema)))
schemas.put(countSchema.colFamilyName, countSchema)
}
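The `_marker` workaround above exists because Avro's variable-length ints use zig-zag encoding rather than the big-endian layout the byte-comparable state encoders expect. A quick sketch of zig-zag for 64-bit values (the standard Avro/Protobuf definition, not code from this PR):

```python
# Zig-zag maps signed ints to unsigned ones so small magnitudes (positive or
# negative) encode into few bytes; the result is NOT byte-comparable the way
# a big-endian fixed-width layout is.

def zigzag_encode(n: int) -> int:
    return (n << 1) ^ (n >> 63)  # arithmetic shift carries the sign

def zigzag_decode(z: int) -> int:
    return (z >> 1) ^ -(z & 1)
```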
@@ -363,18 +363,36 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
addTimerColFamily()
}

/**
 * This method returns all column family schemas, and checks or enforces nullability
 * if need be. The nullability check and set are only enabled when Avro is used.
 * @param shouldCheckNullable Whether we need to check the nullability. This is set to
 *                            true when using Python, as this is the only avenue through
 *                            which users can set nullability
 * @param shouldSetNullable Whether we need to set the fields as nullable. This is set to
 *                          true when using Scala, as case classes are set to
[Contributor] > case classes are set to non-nullable by default.

I'm actually surprised, and it sounds like a bug to me. (Sorry you had to handle Python and Scala differently due to this. My bad.)

What if you set null on any of the fields in a case class? Will it work, and if it works, how?

If this is indeed a bug and we can fix it, then we can simplify things a lot. I'm OK if you want to defer this, but we definitely need a follow-up ticket for it.

[Contributor] Maybe it is only true for primitive types? If so, it might make sense, like an optimization for types which could never be null. If you see non-nullable for String or the like, that should be a bug.

[Contributor @HeartSaVioR, Feb 7, 2025] I found def nullable: Boolean = !isPrimitive in the AgnosticEncoder trait, so it's intended and not a bug. For the record, I've asked to update the comment, since the fields are not always non-nullable - only primitive types are non-nullable.

 *                          non-nullable by default.
 * @return column family schemas used by this stateful processor.
 */
  def getColumnFamilySchemas(
-      setNullableFields: Boolean
+      shouldCheckNullable: Boolean,
+      shouldSetNullable: Boolean
  ): Map[String, StateStoreColFamilySchema] = {
    val schemas = columnFamilySchemas.toMap
-    if (setNullableFields) {
-      schemas.map { case (colFamilyName, stateStoreColFamilySchema) =>
-        colFamilyName -> stateStoreColFamilySchema.copy(
-          valueSchema = stateStoreColFamilySchema.valueSchema.toNullable
-        )
-      }
-    } else {
-      schemas
-    }
+    schemas.map { case (colFamilyName, schema) =>
+      schema.valueSchema.fields.foreach { field =>
+        if (!field.nullable && shouldCheckNullable) {
+          throw StateStoreErrors.twsSchemaMustBeNullable(
+            schema.colFamilyName, schema.valueSchema.toString())
+        }
+      }
+      if (shouldSetNullable) {
+        colFamilyName -> schema.copy(
+          valueSchema = schema.valueSchema.toNullable
+        )
+      } else {
+        colFamilyName -> schema
+      }
+    }
  }
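The control flow of the rewritten method - check nullability on the Python path, force nullability on the Scala path - can be mirrored in a few lines of plain Python (a sketch with stand-in types, not Spark code):

```python
# Stand-in schema: column-family name -> list of (field_name, nullable) pairs.

def get_col_family_schemas(schemas, should_check_nullable, should_set_nullable):
    result = {}
    for name, fields in schemas.items():
        # Python path: users can declare nullability, so reject non-nullable fields
        if should_check_nullable and any(not nullable for _, nullable in fields):
            raise ValueError(
                f"[TRANSFORM_WITH_STATE_SCHEMA_MUST_BE_NULLABLE] column family {name}"
            )
        # Scala path: case-class encoders mark primitives non-nullable, so flip them
        if should_set_nullable:
            fields = [(fname, True) for fname, _ in fields]
        result[name] = fields
    return result
```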

@@ -549,7 +567,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
elementKeySchema: StructType): StateStoreColFamilySchema = {
val countIndexName = s"$$count_$stateName"
    val countValueSchema = StructType(Seq(
-      StructField("count", LongType, nullable = false)
+      StructField("count", LongType)
    ))

StateStoreColFamilySchema(
@@ -357,7 +357,7 @@ abstract class OneToManyTTLState(
// Schema of the entry count index: elementKey -> count
private val COUNT_INDEX = "$count_" + stateName
  private val COUNT_INDEX_VALUE_SCHEMA: StructType =
-    StructType(Seq(StructField("count", LongType, nullable = false)))
+    StructType(Seq(StructField("count", LongType)))
private val countIndexValueProjector = UnsafeProjection.create(COUNT_INDEX_VALUE_SCHEMA)

// Reused internal row that we use to create an UnsafeRow with the schema of
@@ -140,7 +140,7 @@ case class TransformWithStateExec(
* after init is called.
*/
  override def getColFamilySchemas(
-      setNullableFields: Boolean
+      shouldBeNullable: Boolean
  ): Map[String, StateStoreColFamilySchema] = {
val keySchema = keyExpressions.toStructType
// we have to add the default column family schema because the RocksDBStateEncoder
@@ -149,8 +149,11 @@
0, keyExpressions.toStructType, 0, DUMMY_VALUE_ROW_SCHEMA,
Some(NoPrefixKeyStateEncoderSpec(keySchema)))

+    // For Scala, the user can't explicitly set nullability on the schema, so there is
[Contributor] As I mentioned in another comment, it is not impossible to set nullability on the encoder (although I tend to agree most users won't). Let's not make this conditional.

Also, this concerns me - if we are very confident that users would never be able to set a column to be nullable, why do we need to change the schema when we all know it has to be nullable? What are we worried about if we just do the same as Python?

[Contributor] #49751 (comment)
I realized you had to go this way due to the case class encoder. Sorry about that.

+    // no reason to throw an error, and we can simply set the schema to nullable.
     val columnFamilySchemas = getDriverProcessorHandle()
-      .getColumnFamilySchemas(setNullableFields) ++
+      .getColumnFamilySchemas(
+        shouldCheckNullable = false, shouldSetNullable = shouldBeNullable) ++
Map(StateStore.DEFAULT_COL_FAMILY_NAME -> defaultSchema)
closeProcessorHandle()
columnFamilySchemas
@@ -175,7 +175,9 @@ object TransformWithStateOperatorProperties extends Logging {
*/
trait TransformWithStateMetadataUtils extends Logging {

-  def getColFamilySchemas(setNullableFields: Boolean): Map[String, StateStoreColFamilySchema]
+  // This method will return the column family schemas, and check whether the fields in the
+  // schema are nullable. If Avro encoding is used, we want to enforce nullability.
+  def getColFamilySchemas(shouldBeNullable: Boolean): Map[String, StateStoreColFamilySchema]

def getStateVariableInfos(): Map[String, TransformWithStateVariableInfo]

@@ -145,6 +145,12 @@ object StateStoreErrors {
new StateStoreValueSchemaNotCompatible(storedValueSchema, newValueSchema)
}

def twsSchemaMustBeNullable(
[Contributor] I think TWS deserves its own error collection class, but I agree this is out of scope. Let's make a follow-up.

      columnFamilyName: String,
      schema: String): TWSSchemaMustBeNullable = {
    new TWSSchemaMustBeNullable(columnFamilyName, schema)
  }

def stateStoreInvalidValueSchemaEvolution(
oldValueSchema: String,
newValueSchema: String): StateStoreInvalidValueSchemaEvolution = {
@@ -346,6 +352,15 @@ class StateStoreValueSchemaNotCompatible(
"storedValueSchema" -> storedValueSchema,
"newValueSchema" -> newValueSchema))

class TWSSchemaMustBeNullable(
    columnFamilyName: String,
    schema: String)
  extends SparkUnsupportedOperationException(
    errorClass = "TRANSFORM_WITH_STATE_SCHEMA_MUST_BE_NULLABLE",
    messageParameters = Map(
      "columnFamilyName" -> columnFamilyName,
      "schema" -> schema))

class StateStoreInvalidValueSchemaEvolution(
oldValueSchema: String,
newValueSchema: String)
@@ -273,7 +273,6 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
) {
withTempDir { checkpointDir =>
      // When Avro is used, we want to set the StructFields to nullable
-      val shouldBeNullable = usingAvroEncoding()
val metadataPathPostfix = "state/0/_stateSchema/default"
val stateSchemaPath = new Path(checkpointDir.toString, s"$metadataPathPostfix")
val hadoopConf = spark.sessionState.newHadoopConf()
@@ -317,15 +316,15 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
val schema2 = StateStoreColFamilySchema(
"$count_listState", 0,
keySchema, 0,
new StructType().add("count", LongType, nullable = shouldBeNullable),
new StructType().add("count", LongType, nullable = true),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
)

val schema3 = StateStoreColFamilySchema(
"$rowCounter_listState", 0,
keySchema, 0,
new StructType().add("count", LongType, nullable = shouldBeNullable),
new StructType().add("count", LongType, nullable = true),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
)
@@ -409,7 +408,7 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
"valueStateTTL", 0,
keySchema, 0,
new StructType()
.add("value", new StructType().add("value", IntegerType, nullable = shouldBeNullable))
.add("value", new StructType().add("value", IntegerType, nullable = true))
.add("ttlExpirationMs", LongType),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
@@ -418,7 +417,7 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
val schema10 = StateStoreColFamilySchema(
"valueState", 0,
keySchema, 0,
new StructType().add("value", IntegerType, nullable = shouldBeNullable),
new StructType().add("value", IntegerType, nullable = true),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
)
@@ -428,7 +427,7 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
keySchema, 0,
new StructType()
.add("value", new StructType()
.add("id", LongType, nullable = shouldBeNullable)
.add("id", LongType, nullable = true)
.add("name", StringType))
.add("ttlExpirationMs", LongType),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),