@@ -20,13 +20,15 @@ import org.apache.gluten.execution.{DeltaScanTransformer, ProjectExecTransformer
import org.apache.gluten.extension.columnar.transition.RemoveTransitions

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Cast, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.delta.{DeltaColumnMapping, DeltaParquetFileFormat, NoMapping}
import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

object DeltaPostTransformRules {
@@ -93,6 +95,26 @@ object DeltaPostTransformRules {
}
}

/**
* Checks whether two structurally compatible DataTypes have different struct field names at any
* nesting level.
*/
private def structFieldNamesDiffer(logical: DataType, physical: DataType): Boolean = {
(logical, physical) match {
case (l: StructType, p: StructType) if l.length == p.length =>
l.zip(p).exists {
case (lf, pf) =>
lf.name != pf.name || structFieldNamesDiffer(lf.dataType, pf.dataType)
}
case (l: ArrayType, p: ArrayType) =>
structFieldNamesDiffer(l.elementType, p.elementType)
case (l: MapType, p: MapType) =>
structFieldNamesDiffer(l.keyType, p.keyType) ||
structFieldNamesDiffer(l.valueType, p.valueType)
case _ => false
}
}
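
// Illustrative sketch, not part of the change; the physical names below are
// hypothetical Delta-style UUID names. Under name mapping the physical type
// keeps the shape of the logical type but rewrites nested struct field names,
// which is exactly what this helper detects:
//
//   import org.apache.spark.sql.types._
//   val logicalType  = StructType(Seq(StructField("foo", IntegerType)))
//   val physicalType = StructType(Seq(StructField("col-91ef", IntegerType)))
//   structFieldNamesDiffer(logicalType, physicalType) // => true
//   structFieldNamesDiffer(IntegerType, IntegerType)  // => false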

/**
* This method is only used for the Delta ColumnMapping FileFormat (e.g. nameMapping and
* idMapping) to transform Delta's metadata into Parquet's; each plan should only be
* transformed once.
@@ -115,8 +137,9 @@
)(SparkSession.active)
// Transform each output attribute's name into its physical name so the reader can
// read the data correctly; keep the column order the same as the original output.
val originColumnNames = ListBuffer.empty[String]
val transformedAttrs = ListBuffer.empty[Attribute]
case class ColumnMapping(logicalName: String, logicalType: DataType, physicalAttr: Attribute)
val columnMappings = ListBuffer.empty[ColumnMapping]
val seenNames = mutable.Set.empty[String]
def mapAttribute(attr: Attribute) = {
val newAttr = if (plan.isMetadataColumn(attr)) {
attr
@@ -127,9 +150,8 @@
.createPhysicalAttributes(Seq(attr), fmt.referenceSchema, fmt.columnMappingMode)
.head
}
if (!originColumnNames.contains(attr.name)) {
transformedAttrs += newAttr
originColumnNames += attr.name
if (seenNames.add(attr.name)) {
columnMappings += ColumnMapping(attr.name, attr.dataType, newAttr)
}
newAttr
}
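
// A minimal sketch of the dedup idiom introduced above: unlike the previous
// ListBuffer.contains scan, mutable.Set#add is amortized O(1) and returns true
// only on first insertion, so each logical column is recorded exactly once.
//
//   val seen = scala.collection.mutable.Set.empty[String]
//   seen.add("key") // true  -> record the ColumnMapping
//   seen.add("key") // false -> duplicate attribute, skip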
@@ -169,9 +191,19 @@
scanExecTransformer.copyTagsFrom(plan)
tagColumnMappingRule(scanExecTransformer)

// alias physicalName into tableName
val expr = (transformedAttrs, originColumnNames).zipped.map {
(attr, columnName) => Alias(attr, columnName)(exprId = attr.exprId)
// Alias physical names back to logical names. For struct-typed columns, Delta column
// mapping renames internal field names to physical UUIDs. A top-level Alias only restores
// the column name, not the struct's internal field names. We add a Cast to the logical
// type so that downstream expressions see consistent names.
val expr = columnMappings.map {
cm =>
val projectedExpr: Expression =
if (structFieldNamesDiffer(cm.logicalType, cm.physicalAttr.dataType)) {
Cast(cm.physicalAttr, cm.logicalType)
} else {
cm.physicalAttr
}
Alias(projectedExpr, cm.logicalName)(exprId = cm.physicalAttr.exprId)
}
val projectExecTransformer = ProjectExecTransformer(expr.toSeq, scanExecTransformer)
projectExecTransformer
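
// Illustrative sketch with hypothetical physical names: given a scan attribute
// col-7a3b: STRUCT<col-91ef: INT> whose logical column is cstruct: STRUCT<foo: INT>,
// the projection built above wraps the attribute in a Cast before aliasing, so both
// the column name and the nested field name are restored:
//
//   val physicalAttr = AttributeReference(
//     "col-7a3b", StructType(Seq(StructField("col-91ef", IntegerType))))()
//   val logicalType = StructType(Seq(StructField("foo", IntegerType)))
//   val restored =
//     Alias(Cast(physicalAttr, logicalType), "cstruct")(exprId = physicalAttr.exprId)
//   // restored.dataType => STRUCT<foo: INT>; a bare Alias would rename only the
//   // top-level column and leave the UUID field name visible downstream.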

@@ -399,4 +399,113 @@ abstract class DeltaSuite extends WholeStageTransformerSuite {
checkAnswer(df, Seq(Row(2), Row(3)))
}
}

testWithMinSparkVersion(
"merge with column mapping handles struct field metadata correctly",
"3.4") {
withTable("merge_struct_source", "merge_struct_target") {
spark.sql("""
|CREATE TABLE merge_struct_target(
| key INT NOT NULL,
| value INT NOT NULL,
| cstruct STRUCT<foo: INT>)
|USING DELTA
|TBLPROPERTIES (
| 'delta.minReaderVersion' = '2',
| 'delta.minWriterVersion' = '5',
| 'delta.columnMapping.mode' = 'name')
""".stripMargin)
spark.sql("INSERT INTO merge_struct_target VALUES (0, 0, null)")
spark.sql("INSERT INTO merge_struct_target VALUES (100, 100, named_struct('foo', 42))")

spark.sql(
"CREATE TABLE merge_struct_source (key INT NOT NULL, value INT NOT NULL) USING DELTA")
spark.sql("INSERT INTO merge_struct_source VALUES (1, 1)")

// MERGE with updateNotMatched to test CaseWhen else branch
spark.sql("""
|MERGE INTO merge_struct_target AS target
|USING merge_struct_source AS source
|ON source.key = target.key
|WHEN MATCHED THEN
| UPDATE SET target.value = source.value
|WHEN NOT MATCHED BY SOURCE AND target.key = 100 THEN
| UPDATE SET target.value = 22
""".stripMargin)

val df = runQueryAndCompare(
"SELECT key, value, cstruct FROM merge_struct_target ORDER BY key") { _ => }
checkAnswer(df, Row(0, 0, null) :: Row(100, 22, Row(42)) :: Nil)
}
}

testWithMinSparkVersion(
"merge with column mapping handles array-of-struct field metadata correctly",
"3.4") {
withTable("merge_arraystruct_source", "merge_arraystruct_target") {
spark.sql("""
|CREATE TABLE merge_arraystruct_target(
| key INT NOT NULL,
| tags ARRAY<STRUCT<label: STRING, score: INT>>)
|USING DELTA
|TBLPROPERTIES (
| 'delta.minReaderVersion' = '2',
| 'delta.minWriterVersion' = '5',
| 'delta.columnMapping.mode' = 'name')
""".stripMargin)
spark.sql("INSERT INTO merge_arraystruct_target VALUES (0, null)")
spark.sql(
"INSERT INTO merge_arraystruct_target VALUES " +
"(100, array(named_struct('label', 'a', 'score', 10)))")
spark.sql("CREATE TABLE merge_arraystruct_source (key INT NOT NULL) USING DELTA")
spark.sql("INSERT INTO merge_arraystruct_source VALUES (1)")
// MERGE that leaves the array-of-struct column unchanged via CaseWhen
spark.sql("""
|MERGE INTO merge_arraystruct_target AS target
|USING merge_arraystruct_source AS source
|ON source.key = target.key
|WHEN NOT MATCHED BY SOURCE AND target.key = 100 THEN
| UPDATE SET target.key = 101
""".stripMargin)
val df = runQueryAndCompare("SELECT key, tags FROM merge_arraystruct_target ORDER BY key") {
_ =>
}
checkAnswer(df, Row(0, null) :: Row(101, Seq(Row("a", 10))) :: Nil)
}
}

testWithMinSparkVersion(
"merge with column mapping handles map-of-struct field metadata correctly",
"3.4") {
withTable("merge_mapstruct_source", "merge_mapstruct_target") {
spark.sql("""
|CREATE TABLE merge_mapstruct_target(
| key INT NOT NULL,
| props MAP<STRING, STRUCT<val: INT>>)
|USING DELTA
|TBLPROPERTIES (
| 'delta.minReaderVersion' = '2',
| 'delta.minWriterVersion' = '5',
| 'delta.columnMapping.mode' = 'name')
""".stripMargin)
spark.sql("INSERT INTO merge_mapstruct_target VALUES (0, null)")
spark.sql(
"INSERT INTO merge_mapstruct_target VALUES " +
"(100, map('x', named_struct('val', 99)))")
spark.sql("CREATE TABLE merge_mapstruct_source (key INT NOT NULL) USING DELTA")
spark.sql("INSERT INTO merge_mapstruct_source VALUES (1)")
// MERGE that leaves the map-of-struct column unchanged via CaseWhen
spark.sql("""
|MERGE INTO merge_mapstruct_target AS target
|USING merge_mapstruct_source AS source
|ON source.key = target.key
|WHEN NOT MATCHED BY SOURCE AND target.key = 100 THEN
| UPDATE SET target.key = 101
""".stripMargin)
val df = runQueryAndCompare("SELECT key, props FROM merge_mapstruct_target ORDER BY key") {
_ =>
}
checkAnswer(df, Row(0, null) :: Row(101, Map("x" -> Row(99))) :: Nil)
}
}
}