Skip to content

Commit

Permalink
nested not null ignored if parent is null
Browse files Browse the repository at this point in the history
  • Loading branch information
harperjiang committed Feb 4, 2025
1 parent a6db4e0 commit d2302b1
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@

package org.apache.spark.sql.delta.constraints

import scala.collection.mutable

import org.apache.spark.sql.delta.DeltaErrors
import org.apache.spark.sql.delta.schema.SchemaUtils
import org.apache.spark.sql.delta.util.JsonUtils

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{DataType, StructField, StructType}


/**
* List of invariants that can be defined on a Delta table that will allow us to perform
Expand Down Expand Up @@ -71,13 +73,52 @@ object Invariants {

/** Extract invariants from the given schema */
def getFromSchema(schema: StructType, spark: SparkSession): Seq[Constraint] = {
val columns = SchemaUtils.filterRecursively(schema, checkComplexTypes = false) { field =>
!field.nullable || field.metadata.contains(INVARIANTS_FIELD)
/**
* Find the fields containing constraints, as well as its nearest nullable ancestor
* @return (parent path, the nearest null ancestor idx, field)
*/
def recursiveVisitSchema(
columnPath: Seq[String],
dataType: DataType,
nullableAncestorIdxs: mutable.Buffer[Int]): Seq[(Seq[String], Int, StructField)] = {
dataType match {
case st: StructType =>
st.fields.toList.flatMap { field =>
val includeLevel = if (field.metadata.contains(INVARIANTS_FIELD) || !field.nullable) {
Seq((
columnPath,
if (nullableAncestorIdxs.isEmpty) -1 else nullableAncestorIdxs.last,
field
))
} else {
Nil
}
if (field.nullable) {
nullableAncestorIdxs.append(columnPath.size)
}
val childResults = recursiveVisitSchema(
columnPath :+ field.name, field.dataType, nullableAncestorIdxs)
if (field.nullable) {
nullableAncestorIdxs.trimEnd(1)
}
includeLevel ++ childResults
}
case _ => Nil
}
}
columns.map {
case (parents, field) if !field.nullable =>
Constraints.NotNull(parents :+ field.name)
case (parents, field) =>

recursiveVisitSchema(Nil, schema, new mutable.ArrayBuffer[Int]()).map {
case (parents, nullableAncestor, field) if !field.nullable =>
val fieldPath: Seq[String] = parents :+ field.name
if (nullableAncestor != -1) {
Constraints.Check("",
ArbitraryExpression(spark,
s"${parents.take(nullableAncestor + 1).mkString(".")} is null " +
s"or ${fieldPath.mkString(".")} is not null").expression)
} else {
Constraints.NotNull(fieldPath)
}
case (parents, _, field) =>
val rule = field.metadata.getString(INVARIANTS_FIELD)
val invariant = Option(JsonUtils.mapper.readValue[PersistedRule](rule).unwrap) match {
case Some(PersistedExpression(exprString)) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ import org.apache.spark.sql.delta.{CheckConstraintsTableFeature, DeltaLog, Delta
import org.apache.spark.sql.delta.actions.{Metadata, TableFeatureProtocolUtils}
import org.apache.spark.sql.delta.catalog.DeltaTableV2
import org.apache.spark.sql.delta.constraints.{Constraint, Constraints, Invariants}
import org.apache.spark.sql.delta.constraints.Constraints.NotNull
import org.apache.spark.sql.delta.constraints.Invariants.PersistedExpression
import org.apache.spark.sql.delta.constraints.Constraints.{Check, NotNull}
import org.apache.spark.sql.delta.constraints.Invariants.{ArbitraryExpression, PersistedExpression}
import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.delta.test.DeltaSQLCommandTest
import org.apache.spark.sql.delta.test.DeltaSQLTestUtils
Expand Down Expand Up @@ -192,17 +192,15 @@ class InvariantEnforcementSuite extends QueryTest
.add("key", StringType, nullable = false)
.add("value", IntegerType))
testBatchWriteRejection(
NotNull(Seq("key")),
Check("", ArbitraryExpression(spark, "top is null or top.key is not null").expression),
schema,
spark.createDataFrame(Seq(Row(Row("a", 1)), Row(Row(null, 2))).asJava, schema.asNullable),
"top.key"
)
testBatchWriteRejection(
NotNull(Seq("key")),
schema,
spark.createDataFrame(Seq(Row(Row("a", 1)), Row(null)).asJava, schema.asNullable),
"top.key"
)
tableWithSchema(schema) { path =>
spark.createDataFrame(Seq(Row(Row("a", 1)), Row(null)).asJava, schema.asNullable)
.write.mode("append").format("delta").save(path)
}
}

testQuietly("reject non-nullable array column") {
Expand Down

0 comments on commit d2302b1

Please sign in to comment.