Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-51065][SQL] Disallowing non-nullable schema when Avro encoding is used for TransformWithState #49751

Closed
wants to merge 24 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -5072,6 +5072,14 @@
],
"sqlState" : "42601"
},
"TRANSFORM_WITH_STATE_SCHEMA_MUST_BE_NULLABLE" : {
"message" : [
"If Avro encoding is enabled, all the fields in the schema for column family <columnFamilyName> must be nullable",
"when using the TransformWithState operator.",
"Please make the schema nullable. Current schema: <schema>"
],
"sqlState" : "XXKST"
},
"TRANSPOSE_EXCEED_ROW_LIMIT" : {
"message" : [
"Number of rows exceeds the allowed limit of <maxValues> for TRANSPOSE. If this was intended, set <config> to at least the current row count."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1470,6 +1470,39 @@ def check_exception(error):
check_exception=check_exception,
)

def test_not_nullable_fails(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not have an identical test in Scala as well? I don't see a new test verifying the error.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The thing is, there is no way for a user to specify this using Scala.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes and probably also no.

I agree moderate users may not ever try to get over and just stick with case class or POJO or so. But "we" can imagine a way to get over, exactly the same way how we could support PySpark:

override protected val stateEncoder: ExpressionEncoder[Any] =
    ExpressionEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]]

This is how we come up with state encoder for Python version of FMGWS. This is to serde with Row interface - my rough memory says it's not InternalRow but Row, so, most likely work with GenericRow, but we can try with both GenericRow and InternalRow.

I'm OK with deferring this as follow-up.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now I get that you are not able to test this actually, as we have to just accept non-nullable column and change to nullable. Again I doubt this is just a bug though.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Not really a bug, we figured out.)

with self.sql_conf({"spark.sql.streaming.stateStore.encodingFormat": "avro"}):
with tempfile.TemporaryDirectory() as checkpoint_dir:
input_path = tempfile.mkdtemp()
self._prepare_test_resource1(input_path)

df = self._build_test_df(input_path)

def check_basic_state(batch_df, batch_id):
result = batch_df.collect()[0]
assert result.value["id"] == 0 # First ID from test data
assert result.value["name"] == "name-0"

def check_exception(error):
from pyspark.errors.exceptions.captured import StreamingQueryException

if not isinstance(error, StreamingQueryException):
return False

error_msg = str(error)
return (
"[TRANSFORM_WITH_STATE_SCHEMA_MUST_BE_NULLABLE]" in error_msg
and "column family state must be nullable" in error_msg
)

self._run_evolution_test(
BasicProcessorNotNullable(),
checkpoint_dir,
check_basic_state,
df,
check_exception=check_exception,
)


class SimpleStatefulProcessorWithInitialState(StatefulProcessor):
# this dict is the same as input initial state dataframe
Expand Down Expand Up @@ -1893,6 +1926,27 @@ def close(self) -> None:
pass


class BasicProcessorNotNullable(StatefulProcessor):
    """Stateful processor whose value-state schema is declared non-nullable.

    Used by the Avro-encoding test to verify that a non-nullable state
    schema is rejected with TRANSFORM_WITH_STATE_SCHEMA_MUST_BE_NULLABLE.
    """

    # Value-state schema with both fields explicitly marked non-nullable
    # (nullable=False), which Avro state encoding must reject.
    state_schema = StructType(
        [StructField("id", IntegerType(), False), StructField("name", StringType(), False)]
    )

    def init(self, handle):
        # Register a value state named "state" backed by the non-nullable schema.
        self.state = handle.getValueState("state", self.state_schema)

    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
        # Drain the incoming row batches; only the grouping key is used below.
        for _ in rows:
            pass
        key_id = int(key[0])
        record = (key_id, f"name-{key_id}")
        # Persist the derived (id, name) pair into the value state.
        self.state.update(record)
        yield pd.DataFrame(
            {"id": [key[0]], "value": [{"id": record[0], "name": record[1]}]}
        )

    def close(self) -> None:
        pass


class AddFieldsProcessor(StatefulProcessor):
state_schema = StructType(
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,14 @@ case class TransformWithStateInPandasExec(
override def operatorStateMetadataVersion: Int = 2

override def getColFamilySchemas(
setNullableFields: Boolean
shouldBeNullable: Boolean
): Map[String, StateStoreColFamilySchema] = {
driverProcessorHandle.getColumnFamilySchemas(setNullableFields)
// For Python, the user can explicitly set nullability on schema, so
// we need to throw an error if the schema is nullable
driverProcessorHandle.getColumnFamilySchemas(
shouldCheckNullable = shouldBeNullable,
shouldSetNullable = shouldBeNullable
)
}

override def getStateVariableInfos(): Map[String, TransformWithStateVariableInfo] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ trait ListStateMetricsImpl {
// We keep track of the count of entries in the list in a separate column family
// to avoid scanning the entire list to get the count.
private val counterCFValueSchema: StructType =
StructType(Seq(StructField("count", LongType, nullable = false)))
StructType(Seq(StructField("count", LongType, nullable = true)))

private val counterCFProjection = UnsafeProjection.create(counterCFValueSchema)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ object StateStoreColumnFamilySchemaUtils {
// Byte type is converted to Int in Avro, which doesn't work for us as Avro
// uses zig-zag encoding as opposed to big-endian for Ints
Seq(
StructField(s"${field.name}_marker", BinaryType, nullable = false),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's set nullable = true explicitly?

StructField(s"${field.name}_marker", BinaryType, nullable = true),
field.copy(name = s"${field.name}_value", BinaryType)
)
} else {
Expand Down Expand Up @@ -117,7 +117,7 @@ object StateStoreColumnFamilySchemaUtils {
getRowCounterCFName(stateName), keySchemaId = 0,
keyEncoder.schema,
valueSchemaId = 0,
StructType(Seq(StructField("count", LongType, nullable = false))),
StructType(Seq(StructField("count", LongType, nullable = true))),
Some(NoPrefixKeyStateEncoderSpec(keyEncoder.schema)))
schemas.put(counterSchema.colFamilyName, counterSchema)

Expand Down Expand Up @@ -149,7 +149,7 @@ object StateStoreColumnFamilySchemaUtils {
keySchemaId = 0,
keyEncoder.schema,
valueSchemaId = 0,
StructType(Seq(StructField("count", LongType, nullable = false))),
StructType(Seq(StructField("count", LongType, nullable = true))),
Some(NoPrefixKeyStateEncoderSpec(keyEncoder.schema)))
schemas.put(countSchema.colFamilyName, countSchema)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,18 +363,38 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
addTimerColFamily()
}

/**
* This method returns all column family schemas, and checks and enforces nullability
* if need be. The nullability check and set is only set to true when Avro is enabled.
* @param shouldCheckNullable Whether we need to check the nullability. This is set to
* true when using Python, as this is the only avenue through
* which users can set nullability
* @param shouldSetNullable Whether we need to set the fields as nullable. This is set to
* true when using Scala, as primitive type encoders set the field
* to non-nullable. Changing fields from non-nullable to nullable
* does not break anything (and is required for Avro encoding), so
* we can safely make this change.
* @return column family schemas used by this stateful processor.
*/
def getColumnFamilySchemas(
setNullableFields: Boolean
shouldCheckNullable: Boolean,
shouldSetNullable: Boolean
): Map[String, StateStoreColFamilySchema] = {
val schemas = columnFamilySchemas.toMap
if (setNullableFields) {
schemas.map { case (colFamilyName, stateStoreColFamilySchema) =>
colFamilyName -> stateStoreColFamilySchema.copy(
valueSchema = stateStoreColFamilySchema.valueSchema.toNullable
schemas.map { case (colFamilyName, schema) =>
schema.valueSchema.fields.foreach { field =>
if (!field.nullable && shouldCheckNullable) {
throw StateStoreErrors.twsSchemaMustBeNullable(
schema.colFamilyName, schema.valueSchema.toString())
}
}
if (shouldSetNullable) {
colFamilyName -> schema.copy(
valueSchema = schema.valueSchema.toNullable
)
} else {
colFamilyName -> schema
}
} else {
schemas
}
}

Expand Down Expand Up @@ -549,7 +569,7 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
elementKeySchema: StructType): StateStoreColFamilySchema = {
val countIndexName = s"$$count_$stateName"
val countValueSchema = StructType(Seq(
StructField("count", LongType, nullable = false)
StructField("count", LongType)
))

StateStoreColFamilySchema(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ abstract class OneToManyTTLState(
// Schema of the entry count index: elementKey -> count
private val COUNT_INDEX = "$count_" + stateName
private val COUNT_INDEX_VALUE_SCHEMA: StructType =
StructType(Seq(StructField("count", LongType, nullable = false)))
StructType(Seq(StructField("count", LongType)))
private val countIndexValueProjector = UnsafeProjection.create(COUNT_INDEX_VALUE_SCHEMA)

// Reused internal row that we use to create an UnsafeRow with the schema of
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ case class TransformWithStateExec(
* after init is called.
*/
override def getColFamilySchemas(
setNullableFields: Boolean
shouldBeNullable: Boolean
): Map[String, StateStoreColFamilySchema] = {
val keySchema = keyExpressions.toStructType
// we have to add the default column family schema because the RocksDBStateEncoder
Expand All @@ -149,8 +149,11 @@ case class TransformWithStateExec(
0, keyExpressions.toStructType, 0, DUMMY_VALUE_ROW_SCHEMA,
Some(NoPrefixKeyStateEncoderSpec(keySchema)))

// For Scala, the user can't explicitly set nullability on schema, so there is
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I mentioned in another comment, it is not impossible to set nullability on an encoder (although I tend to agree most users won't). Let's not make this conditional.

Also, this is concerning me - if we are very confident that users would never be able to set column to be nullable, why we need to change the schema as we all know it has to be nullable? What we are worrying about if we just do the same with Python?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#49751 (comment)
I realized you had to go this way due to the case class encoder. Sorry about that.

// no reason to throw an error, and we can simply set the schema to nullable.
val columnFamilySchemas = getDriverProcessorHandle()
.getColumnFamilySchemas(setNullableFields) ++
.getColumnFamilySchemas(
shouldCheckNullable = false, shouldSetNullable = shouldBeNullable) ++
Map(StateStore.DEFAULT_COL_FAMILY_NAME -> defaultSchema)
closeProcessorHandle()
columnFamilySchemas
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@ object TransformWithStateOperatorProperties extends Logging {
*/
trait TransformWithStateMetadataUtils extends Logging {

def getColFamilySchemas(setNullableFields: Boolean): Map[String, StateStoreColFamilySchema]
// This method will return the column family schemas, and check whether the fields in the
// schema are nullable. If Avro encoding is used, we want to enforce nullability
def getColFamilySchemas(shouldBeNullable: Boolean): Map[String, StateStoreColFamilySchema]

def getStateVariableInfos(): Map[String, TransformWithStateVariableInfo]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@ object StateStoreErrors {
new StateStoreValueSchemaNotCompatible(storedValueSchema, newValueSchema)
}

def twsSchemaMustBeNullable(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think TWS deserves its own error collection class, but I agree this is out of scope. Let's make a follow-up.

columnFamilyName: String,
schema: String): TWSSchemaMustBeNullable = {
new TWSSchemaMustBeNullable(columnFamilyName, schema)
}

def stateStoreInvalidValueSchemaEvolution(
oldValueSchema: String,
newValueSchema: String): StateStoreInvalidValueSchemaEvolution = {
Expand Down Expand Up @@ -346,6 +352,15 @@ class StateStoreValueSchemaNotCompatible(
"storedValueSchema" -> storedValueSchema,
"newValueSchema" -> newValueSchema))

class TWSSchemaMustBeNullable(
columnFamilyName: String,
schema: String)
extends SparkUnsupportedOperationException(
errorClass = "TRANSFORM_WITH_STATE_SCHEMA_MUST_BE_NULLABLE",
messageParameters = Map(
"columnFamilyName" -> columnFamilyName,
"schema" -> schema))

class StateStoreInvalidValueSchemaEvolution(
oldValueSchema: String,
newValueSchema: String)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1406,7 +1406,7 @@ abstract class TransformWithStateSuite extends StateStoreMetricsTest
val schema3 = StateStoreColFamilySchema(
"$rowCounter_listState", 0,
keySchema, 0,
new StructType().add("count", LongType, nullable = shouldBeNullable),
new StructType().add("count", LongType, nullable = true),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,15 +317,15 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
val schema2 = StateStoreColFamilySchema(
"$count_listState", 0,
keySchema, 0,
new StructType().add("count", LongType, nullable = shouldBeNullable),
new StructType().add("count", LongType, nullable = true),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
)

val schema3 = StateStoreColFamilySchema(
"$rowCounter_listState", 0,
keySchema, 0,
new StructType().add("count", LongType, nullable = shouldBeNullable),
new StructType().add("count", LongType, nullable = true),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
)
Expand Down