diff --git a/gluten-iceberg/pom.xml b/gluten-iceberg/pom.xml index 51aa0d92bc34..366340aa695a 100644 --- a/gluten-iceberg/pom.xml +++ b/gluten-iceberg/pom.xml @@ -101,6 +101,12 @@ scalatest_${scala.binary.version} test + + org.mockito + mockito-core + 2.23.4 + test + diff --git a/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala b/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala index 8f6dcda7edff..082865da83dc 100644 --- a/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala +++ b/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala @@ -17,6 +17,7 @@ package org.apache.gluten.execution import org.apache.gluten.backendsapi.BackendsApiManager +import org.apache.gluten.config.GlutenConfig import org.apache.gluten.exception.GlutenNotSupportException import org.apache.gluten.execution.IcebergScanTransformer.{containsMetadataColumn, containsUuidOrFixedType} import org.apache.gluten.sql.shims.SparkShimLoader @@ -195,6 +196,16 @@ case class IcebergScanTransformer( metadataColumns ++ extraMetadataColumns } + @transient override protected lazy val finalPartitions: Seq[Partition] = { + if (keyGroupedPartitioning.isDefined) { + getFinalPartitions + } else { + GlutenIcebergSourceUtil.regeneratePartitions( + getFinalPartitions, + GlutenConfig.get.smallFileThreshold) + } + } + override lazy val fileFormat: ReadFileFormat = GlutenIcebergSourceUtil.getFileFormat(scan) override def getSplitInfosFromPartitions( diff --git a/gluten-iceberg/src/main/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtil.scala b/gluten-iceberg/src/main/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtil.scala index 1a658a78925e..13e8e01268f5 100644 --- a/gluten-iceberg/src/main/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtil.scala +++ b/gluten-iceberg/src/main/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtil.scala @@ 
-22,6 +22,9 @@ import org.apache.gluten.execution.SparkDataSourceRDDPartition
 import org.apache.gluten.substrait.rel.{IcebergLocalFilesBuilder, SplitInfo}
 import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
 
+import org.apache.spark.Partition
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
 import org.apache.spark.softaffinity.SoftAffinity
 import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
 import org.apache.spark.sql.connector.read.Scan
@@ -29,14 +32,16 @@ import org.apache.spark.sql.types.StructType
 
 import org.apache.iceberg._
 import org.apache.iceberg.spark.SparkSchemaUtil
+import org.apache.iceberg.util.TableScanUtil
 
 import java.lang.{Class, Long => JLong}
 import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList, Map => JMap}
 import java.util.Locale
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 
-object GlutenIcebergSourceUtil {
+object GlutenIcebergSourceUtil extends Logging {
   private val InputFileNameCol = "input_file_name"
   private val InputFileBlockStartCol = "input_file_block_start"
   private val InputFileBlockLengthCol = "input_file_block_length"
@@ -234,4 +239,199 @@ object GlutenIcebergSourceUtil {
     case _ =>
       throw new GlutenNotSupportException("Iceberg Only support parquet and orc file format.")
   }
+
+  /**
+   * Rebalances the Iceberg scan tasks held by `inputPartitions` across the same number of
+   * partitions.
+   *
+   * Every scan task is collected, ordered by size and handed to the least-loaded bucket. When
+   * `smallFileThreshold` is positive, the smallest tasks whose cumulative size stays within
+   * `smallFileThreshold * totalSize` are spread by file count first; the remaining tasks are
+   * then spread by accumulated size.
+   *
+   * @param inputPartitions
+   *   partitions to rebalance; each must be a [[SparkDataSourceRDDPartition]] wrapping
+   *   `SparkInputPartition`s
+   * @param smallFileThreshold
+   *   fraction of the total scan size treated as small files; a non-positive value disables the
+   *   file-count pass
+   * @return
+   *   one partition per input partition, or `inputPartitions` unchanged when it holds no input
+   *   partition to copy the shared settings from
+   * @throws GlutenNotSupportException
+   *   if a partition, input partition or scan task has an unsupported type
+   */
+  def regeneratePartitions(
+      inputPartitions: Seq[Partition],
+      smallFileThreshold: Double): Seq[Partition] = {
+    val icebergPartitions: Seq[SparkDataSourceRDDPartition] = inputPartitions.map {
+      case partition: SparkDataSourceRDDPartition => partition
+      case other =>
+        throw new GlutenNotSupportException(
+          s"Unsupported partition type: ${other.getClass.getSimpleName}")
+    }
+
+    val sparkInputPartitions = icebergPartitions.flatMap(_.inputPartitions).map {
+      case partition: SparkInputPartition => partition
+      case other =>
+        throw new GlutenNotSupportException(
+          s"Unsupported input partition type: ${other.getClass.getSimpleName}")
+    }
+
+    // One bucket of scan tasks per output partition; filled through addToBucket.
+    val partitionedTasks = Array.fill(icebergPartitions.size)(mutable.ArrayBuffer.empty[ScanTask])
+
+    // Snapshots the constructor arguments of a SparkInputPartition; three of them are read
+    // reflectively from its declared fields.
+    def getSparkInputPartitionContext(
+        inputPartition: SparkInputPartition): SparkPartitionContext = {
+      val clazz = classOf[SparkInputPartition]
+      def readField[T](fieldName: String): T = {
+        val field = clazz.getDeclaredField(fieldName)
+        field.setAccessible(true)
+        field.get(inputPartition).asInstanceOf[T]
+      }
+
+      SparkPartitionContext(
+        groupingKeyType = readField[org.apache.iceberg.types.Types.StructType]("groupingKeyType"),
+        tableBroadcast = readField[Broadcast[Table]]("tableBroadcast"),
+        branch = inputPartition.branch(),
+        expectedSchemaString = readField[String]("expectedSchemaString"),
+        caseSensitive = inputPartition.isCaseSensitive,
+        preferredLocations = inputPartition.preferredLocations(),
+        cacheDeleteFilesOnExecutors = inputPartition.cacheDeleteFilesOnExecutors()
+      )
+    }
+
+    // Returns the partition's scan tasks, rejecting anything but file and combined scan tasks.
+    def getScanTasks(inputPartition: SparkInputPartition): Seq[ScanTask] = {
+      inputPartition.taskGroup[ScanTask]().tasks().asScala.toSeq.map {
+        case task if task.isFileScanTask => task
+        case task: CombinedScanTask => task
+        case other =>
+          throw new GlutenNotSupportException(
+            s"Unsupported scan task type: ${other.getClass.getSimpleName}")
+      }
+    }
+
+    // Length of a file scan task, or the summed lengths of a combined task's file tasks.
+    def getScanTaskSize(scanTask: ScanTask): Long = scanTask match {
+      case task if task.isFileScanTask => task.asFileScanTask().length()
+      case task: CombinedScanTask => task.tasks().asScala.map(_.length()).sum
+      case other =>
+        throw new GlutenNotSupportException(
+          s"Unsupported scan task type: ${other.getClass.getSimpleName}")
+    }
+
+    // Pops the bucket at the head of the heap, appends the task and re-queues the bucket with
+    // its updated size and file count.
+    def addToBucket(
+        heap: mutable.PriorityQueue[(Long, Int, Int)],
+        scanTask: ScanTask,
+        taskSize: Long): Unit = {
+      val (size, numFiles, idx) = heap.dequeue()
+      partitionedTasks(idx) += scanTask
+      heap.enqueue((size + taskSize, numFiles + 1, idx))
+    }
+
+    // Creates a heap holding one empty (size, file count, bucket index) entry per bucket.
+    def initializeHeap(
+        ordering: Ordering[(Long, Int, Int)]): mutable.PriorityQueue[(Long, Int, Int)] = {
+      val heap = mutable.PriorityQueue.empty[(Long, Int, Int)](ordering)
+      icebergPartitions.indices.foreach(i => heap.enqueue((0L, 0, i)))
+      heap
+    }
+
+    // Builds a SparkInputPartition over the merged tasks from the captured settings.
+    def createSparkInputPartition(
+        context: SparkPartitionContext,
+        tasks: Seq[ScanTask]): SparkInputPartition = {
+      val taskGroup = new BaseScanTaskGroup[ScanTask](TableScanUtil.mergeTasks(tasks.asJava))
+      new SparkInputPartition(
+        context.groupingKeyType,
+        taskGroup,
+        context.tableBroadcast,
+        context.branch,
+        context.expectedSchemaString,
+        context.caseSensitive,
+        context.preferredLocations,
+        context.cacheDeleteFilesOnExecutors
+      )
+    }
+
+    sparkInputPartitions.headOption match {
+      case None =>
+        // No input partition to copy the shared settings from, hence nothing to rebalance.
+        inputPartitions
+      case Some(firstPartition) =>
+        // NOTE(review): every regenerated partition reuses the first partition's settings,
+        // including its preferred locations; confirm that this is intended.
+        val context = getSparkInputPartitionContext(firstPartition)
+        val scanTasks = sparkInputPartitions.flatMap(getScanTasks)
+        // Largest task first.
+        val sortedScanTasks = scanTasks
+          .zip(scanTasks.map(getScanTaskSize))
+          .sortBy(_._2)(Ordering.Long.reverse)
+
+        val sizeFirstOrdering = Ordering
+          .by[(Long, Int, Int), (Long, Int)] { case (size, numFiles, _) => (size, numFiles) }
+          .reverse
+
+        if (smallFileThreshold > 0) {
+          val smallFileTotalSize = sortedScanTasks.map(_._2).sum * smallFileThreshold
+          // Number of smallest tasks whose running total stays within the small-file budget.
+          val numSmallFiles = sortedScanTasks.reverseIterator
+            .map(_._2)
+            .scanLeft(0L)(_ + _)
+            .drop(1)
+            .takeWhile(_ <= smallFileTotalSize)
+            .size
+          val (largeTasks, smallTasks) =
+            sortedScanTasks.splitAt(sortedScanTasks.size - numSmallFiles)
+
+          val numFirstOrdering = Ordering
+            .by[(Long, Int, Int), (Int, Long)] { case (size, numFiles, _) => (numFiles, size) }
+            .reverse
+          val heapByFileNum = initializeHeap(numFirstOrdering)
+          // Smallest task first, so that the small files are spread by file count.
+          smallTasks.reverseIterator.foreach {
+            case (task, taskSize) => addToBucket(heapByFileNum, task, taskSize)
+          }
+
+          // Re-rank the partially filled buckets by accumulated size for the remaining tasks.
+          val heapByFileSize = mutable.PriorityQueue.empty[(Long, Int, Int)](sizeFirstOrdering)
+          while (heapByFileNum.nonEmpty) {
+            heapByFileSize.enqueue(heapByFileNum.dequeue())
+          }
+
+          largeTasks.foreach {
+            case (task, taskSize) => addToBucket(heapByFileSize, task, taskSize)
+          }
+        } else {
+          val heapByFileSize = initializeHeap(sizeFirstOrdering)
+          sortedScanTasks.foreach {
+            case (task, taskSize) => addToBucket(heapByFileSize, task, taskSize)
+          }
+        }
+
+        partitionedTasks.zipWithIndex.map {
+          case (tasks, idx) =>
+            val newPartition = createSparkInputPartition(context, tasks.toSeq)
+            new SparkDataSourceRDDPartition(idx, Seq(newPartition))
+        }.toSeq
+    }
+  }
 }
+
+/**
+ * Constructor arguments of a `SparkInputPartition` other than its task group, captured so that
+ * regenerated partitions can be built with the same settings.
+ */
+case class SparkPartitionContext(
+    groupingKeyType: org.apache.iceberg.types.Types.StructType,
+    tableBroadcast: Broadcast[Table],
+    branch: String,
+    expectedSchemaString: String,
+    caseSensitive: Boolean,
+    preferredLocations: Array[String],
+    cacheDeleteFilesOnExecutors: Boolean)
diff --git a/gluten-iceberg/src/test/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtilSuite.scala b/gluten-iceberg/src/test/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtilSuite.scala
new file mode 100644
index 000000000000..037de8d14794
--- /dev/null
+++ b/gluten-iceberg/src/test/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtilSuite.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.iceberg.spark.source
+
+import org.apache.gluten.execution.SparkDataSourceRDDPartition
+
+import org.apache.spark.Partition
+import org.apache.spark.broadcast.Broadcast
+
+import org.apache.iceberg.{FileScanTask, ScanTask, ScanTaskGroup, Table}
+import org.apache.iceberg.types.Types
+import org.mockito.Mockito.{mock, when}
+import org.scalatest.funsuite.AnyFunSuite
+
+import java.util.Collections
+
+import scala.collection.JavaConverters._
+
+class GlutenIcebergSourceUtilSuite extends AnyFunSuite {
+
+  /**
+   * Builds a SparkInputPartition whose task group holds one mocked file scan task of `length`.
+   */
+  private def makeSparkInputPartition(length: Long): SparkInputPartition = {
+    val fileTask = mock(classOf[FileScanTask])
+    when(fileTask.isFileScanTask).thenReturn(true)
+    when(fileTask.asFileScanTask()).thenReturn(fileTask)
+    when(fileTask.length()).thenReturn(length)
+
+    val singleTaskGroup = new ScanTaskGroup[ScanTask] {
+      override def tasks(): java.util.Collection[ScanTask] = Collections.singletonList(fileTask)
+      override def sizeBytes(): Long = length
+      override def estimatedRowsCount(): Long = 0L
+      override def filesCount(): Int = 1
+    }
+
+    // The constructor is looked up reflectively and made accessible before it is invoked.
+    val parameterTypes: Seq[Class[_]] = Seq(
+      classOf[Types.StructType],
+      classOf[ScanTaskGroup[_]],
+      classOf[Broadcast[Table]],
+      classOf[String],
+      classOf[String],
+      java.lang.Boolean.TYPE,
+      classOf[Array[String]],
+      java.lang.Boolean.TYPE
+    )
+    val arguments: Seq[AnyRef] = Seq(
+      Types.StructType.of(),
+      singleTaskGroup,
+      null,
+      null,
+      null,
+      Boolean.box(false),
+      Array.empty[String],
+      Boolean.box(false)
+    )
+    val ctor = classOf[SparkInputPartition].getDeclaredConstructor(parameterTypes: _*)
+    ctor.setAccessible(true)
+    ctor.newInstance(arguments: _*)
+  }
+
+  /**
+   * Splits `inputPartitions` into consecutive groups of ceil(size / numPartitions) items.
+   */
+  private def makePartitions(
+      inputPartitions: Seq[SparkInputPartition],
+      numPartitions: Int): Seq[Partition] = {
+    val remainder = inputPartitions.size % numPartitions
+    val groupSize = inputPartitions.size / numPartitions + (if (remainder == 0) 0 else 1)
+    val groups = inputPartitions.grouped(groupSize).toSeq
+    groups.zipWithIndex.map { case (group, idx) => new SparkDataSourceRDDPartition(idx, group) }
+  }
+
+  /**
+   * Regenerates the partitions built from one single-task input partition per given length.
+   */
+  private def regenerate(
+      lengths: Seq[Long],
+      numPartitions: Int,
+      smallFileThreshold: Double): Seq[Partition] = {
+    val initialPartitions = makePartitions(lengths.map(makeSparkInputPartition), numPartitions)
+    GlutenIcebergSourceUtil.regeneratePartitions(initialPartitions, smallFileThreshold)
+  }
+
+  /**
+   * Task collections of every SparkInputPartition wrapped by `partition`, in partition order.
+   */
+  private def taskCollections(partition: Partition): Seq[java.util.Collection[ScanTask]] = {
+    val rddPartition = partition.asInstanceOf[SparkDataSourceRDDPartition]
+    rddPartition.inputPartitions.map {
+      inputPartition =>
+        inputPartition.asInstanceOf[SparkInputPartition].taskGroup[ScanTask]().tasks()
+    }
+  }
+
+  /** Lengths of all file scan tasks held by `partition`. */
+  private def partitionLengths(partition: Partition): Seq[Long] = {
+    val tasks = taskCollections(partition).flatMap(_.asScala)
+    tasks.map(_.asFileScanTask().length())
+  }
+
+  /** Total number of scan tasks held by `partition`. */
+  private def partitionFileNums(partition: Partition): Int = {
+    val taskCounts = taskCollections(partition).map(_.size)
+    taskCounts.sum
+  }
+
+  test("large files are distributed evenly by size") {
+    val regenerated =
+      regenerate(Seq(100L, 90L, 80L, 70L), numPartitions = 2, smallFileThreshold = 0.0)
+
+    assert(regenerated.size === 2)
+    val totalLengths = regenerated.map(partition => partitionLengths(partition).sum)
+    assert(totalLengths.forall(_ === 170))
+  }
+
+  test("small files are distributed evenly by number of files") {
+    val regenerated = regenerate(Seq.fill(10)(10L), numPartitions = 5, smallFileThreshold = 1.0)
+
+    assert(regenerated.size === 5)
+    val taskCounts = regenerated.map(partition => partitionLengths(partition).size)
+    assert(taskCounts.forall(_ === 2))
+  }
+
+  test("small files should not be placed into one partition") {
+    val regenerated =
+      regenerate(Seq(10L, 20L, 30L, 40L, 100L), numPartitions = 2, smallFileThreshold = 0.5)
+
+    assert(regenerated.size === 2)
+    assert(regenerated.forall(partition => partitionLengths(partition).exists(_ <= 40)))
+  }
+
+  test("mixed small and large files should be evenly distributed") {
+    val lengths = Seq(10L, 20L, 30L, 40L, 50L, 60L, 70L, 80L, 90L, 100L)
+    val regenerated = regenerate(lengths, numPartitions = 3, smallFileThreshold = 0.5)
+
+    assert(regenerated.size === 3)
+    assert(regenerated.forall(partition => partitionFileNums(partition) >= 3))
+  }
+
+  test("zero length files") {
+    val regenerated = regenerate(Seq(0L, 0L), numPartitions = 2, smallFileThreshold = 0.0)
+
+    assert(regenerated.size === 2)
+    assert(regenerated.count(partition => partitionLengths(partition).nonEmpty) === 2)
+  }
+
+  test("empty inputs") {
+    val regenerated = GlutenIcebergSourceUtil.regeneratePartitions(Seq.empty, 0.5)
+    assert(regenerated.size === 0)
+  }
+}
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala
index a0c3bb875753..90016f652a96 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala
@@ -173,7 +173,7 @@ abstract class BatchScanExecTransformerBase(
   override def metricsUpdater(): MetricsUpdater =
     BackendsApiManager.getMetricsApiInstance.genBatchScanTransformerMetricsUpdater(metrics)
 
-  @transient protected lazy val finalPartitions: Seq[Partition] =
+  def getFinalPartitions: Seq[Partition] =
     SparkShimLoader.getSparkShims
       .orderPartitions(
         this,
@@ -189,6 +189,8 @@ abstract class BatchScanExecTransformerBase(
       case (inputPartitions, index) => new SparkDataSourceRDDPartition(index, inputPartitions)
     }
 
+  @transient protected lazy val finalPartitions: Seq[Partition] = getFinalPartitions
+
   @transient override lazy val fileFormat: ReadFileFormat =
     BackendsApiManager.getSettings.getSubstraitReadFileFormatV2(scan)
 