diff --git a/gluten-iceberg/pom.xml b/gluten-iceberg/pom.xml
index 51aa0d92bc34..366340aa695a 100644
--- a/gluten-iceberg/pom.xml
+++ b/gluten-iceberg/pom.xml
@@ -101,6 +101,12 @@
scalatest_${scala.binary.version}
test
+
+ org.mockito
+ mockito-core
+ 2.23.4
+ test
+
diff --git a/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala b/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala
index 8f6dcda7edff..082865da83dc 100644
--- a/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala
+++ b/gluten-iceberg/src/main/scala/org/apache/gluten/execution/IcebergScanTransformer.scala
@@ -17,6 +17,7 @@
package org.apache.gluten.execution
import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.execution.IcebergScanTransformer.{containsMetadataColumn, containsUuidOrFixedType}
import org.apache.gluten.sql.shims.SparkShimLoader
@@ -195,6 +196,16 @@ case class IcebergScanTransformer(
metadataColumns ++ extraMetadataColumns
}
+ @transient override protected lazy val finalPartitions: Seq[Partition] = {
+ if (keyGroupedPartitioning.isDefined) {
+ getFinalPartitions
+ } else {
+ GlutenIcebergSourceUtil.regeneratePartitions(
+ getFinalPartitions,
+ GlutenConfig.get.smallFileThreshold)
+ }
+ }
+
override lazy val fileFormat: ReadFileFormat = GlutenIcebergSourceUtil.getFileFormat(scan)
override def getSplitInfosFromPartitions(
diff --git a/gluten-iceberg/src/main/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtil.scala b/gluten-iceberg/src/main/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtil.scala
index 1a658a78925e..13e8e01268f5 100644
--- a/gluten-iceberg/src/main/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtil.scala
+++ b/gluten-iceberg/src/main/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtil.scala
@@ -22,6 +22,9 @@ import org.apache.gluten.execution.SparkDataSourceRDDPartition
import org.apache.gluten.substrait.rel.{IcebergLocalFilesBuilder, SplitInfo}
import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
+import org.apache.spark.Partition
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
import org.apache.spark.softaffinity.SoftAffinity
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
import org.apache.spark.sql.connector.read.Scan
@@ -29,14 +32,16 @@ import org.apache.spark.sql.types.StructType
import org.apache.iceberg._
import org.apache.iceberg.spark.SparkSchemaUtil
+import org.apache.iceberg.util.TableScanUtil
import java.lang.{Class, Long => JLong}
import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList, Map => JMap}
import java.util.Locale
import scala.collection.JavaConverters._
+import scala.collection.mutable
-object GlutenIcebergSourceUtil {
+object GlutenIcebergSourceUtil extends Logging {
private val InputFileNameCol = "input_file_name"
private val InputFileBlockStartCol = "input_file_block_start"
private val InputFileBlockLengthCol = "input_file_block_length"
@@ -234,4 +239,158 @@ object GlutenIcebergSourceUtil {
case _ =>
throw new GlutenNotSupportException("Iceberg Only support parquet and orc file format.")
}
+
+ def regeneratePartitions(
+ inputPartitions: Seq[Partition],
+ smallFileThreshold: Double): Seq[Partition] = {
+ if (inputPartitions.isEmpty) {
+ return Seq.empty
+ }
+
+ val icebergPartitions: Seq[SparkDataSourceRDDPartition] = inputPartitions.map {
+ case partition: SparkDataSourceRDDPartition => partition
+ case other =>
+ throw new GlutenNotSupportException(
+ s"Unsupported partition type: ${other.getClass.getSimpleName}")
+ }
+
+ val partitionedTasks = Array.fill(icebergPartitions.size)(mutable.ArrayBuffer.empty[ScanTask])
+
+ def getSparkInputPartitionContext(
+ inputPartition: SparkInputPartition): SparkPartitionContext = {
+ val clazz = classOf[SparkInputPartition]
+ def readField[T](fieldName: String): T = {
+ val field = clazz.getDeclaredField(fieldName)
+ field.setAccessible(true)
+ field.get(inputPartition).asInstanceOf[T]
+ }
+
+ SparkPartitionContext(
+ groupingKeyType = readField[org.apache.iceberg.types.Types.StructType]("groupingKeyType"),
+ tableBroadcast = readField[Broadcast[Table]]("tableBroadcast"),
+ branch = inputPartition.branch(),
+ expectedSchemaString = readField[String]("expectedSchemaString"),
+ caseSensitive = inputPartition.isCaseSensitive,
+ preferredLocations = inputPartition.preferredLocations(),
+ cacheDeleteFilesOnExecutors = inputPartition.cacheDeleteFilesOnExecutors()
+ )
+ }
+
+ def getScanTasks(inputPartition: SparkInputPartition): Seq[ScanTask] = {
+ inputPartition.taskGroup[ScanTask]().tasks().asScala.toSeq.map {
+ case task if task.isFileScanTask => task
+ case task: CombinedScanTask => task
+ case other =>
+ throw new GlutenNotSupportException(
+ s"Unsupported scan task type: ${other.getClass.getSimpleName}")
+ }
+ }
+
+ def getScanTaskSize(scanTask: ScanTask): Long = scanTask match {
+ case task if task.isFileScanTask => task.asFileScanTask().length()
+ case task: CombinedScanTask => task.tasks().asScala.map(_.length()).sum
+ case other =>
+ throw new GlutenNotSupportException(
+ s"Unsupported scan task type: ${other.getClass.getSimpleName}")
+ }
+
+ def addToBucket(
+ heap: mutable.PriorityQueue[(Long, Int, Int)],
+ scanTask: ScanTask,
+ taskSize: Long): Unit = {
+ val (size, numFiles, idx) = heap.dequeue()
+ partitionedTasks(idx) += scanTask
+ heap.enqueue((size + taskSize, numFiles + 1, idx))
+ }
+
+ def initializeHeap(
+ ordering: Ordering[(Long, Int, Int)]): mutable.PriorityQueue[(Long, Int, Int)] = {
+ val heap = mutable.PriorityQueue.empty[(Long, Int, Int)](ordering)
+ icebergPartitions.indices.foreach(i => heap.enqueue((0L, 0, i)))
+ heap
+ }
+
+ def createSparkInputPartition(
+ context: SparkPartitionContext,
+ tasks: Seq[ScanTask]): SparkInputPartition = {
+ val taskGroup = new BaseScanTaskGroup[ScanTask](TableScanUtil.mergeTasks(tasks.asJava))
+ new SparkInputPartition(
+ context.groupingKeyType,
+ taskGroup,
+ context.tableBroadcast,
+ context.branch,
+ context.expectedSchemaString,
+ context.caseSensitive,
+ context.preferredLocations,
+ context.cacheDeleteFilesOnExecutors
+ )
+ }
+
+ val sparkInputPartitions = icebergPartitions.flatMap(_.inputPartitions).map {
+ case partition: SparkInputPartition => partition
+ case other =>
+ throw new GlutenNotSupportException(
+ s"Unsupported input partition type: ${other.getClass.getSimpleName}")
+ }
+
+ val context = getSparkInputPartitionContext(sparkInputPartitions.head)
+ val scanTasks = sparkInputPartitions.flatMap(getScanTasks)
+ val sortedScanTasks = scanTasks
+ .zip(scanTasks.map(getScanTaskSize))
+ .sortBy(_._2)(Ordering.Long.reverse)
+
+ val sizeFirstOrdering = Ordering
+ .by[(Long, Int, Int), (Long, Int)] { case (size, numFiles, _) => (size, numFiles) }
+ .reverse
+
+ if (smallFileThreshold > 0) {
+ val smallFileTotalSize = sortedScanTasks.map(_._2).sum * smallFileThreshold
+ val numFirstOrdering = Ordering
+ .by[(Long, Int, Int), (Int, Long)] { case (size, numFiles, _) => (numFiles, size) }
+ .reverse
+ val heapByFileNum = initializeHeap(numFirstOrdering)
+
+ var numSmallFiles = 0
+ var smallFileSize = 0L
+ sortedScanTasks.reverseIterator
+ .takeWhile(task => task._2 + smallFileSize <= smallFileTotalSize)
+ .foreach {
+ case (task, taskSize) =>
+ addToBucket(heapByFileNum, task, taskSize)
+ numSmallFiles += 1
+ smallFileSize += taskSize
+ }
+
+ val heapByFileSize = mutable.PriorityQueue.empty[(Long, Int, Int)](sizeFirstOrdering)
+ while (heapByFileNum.nonEmpty) {
+ heapByFileSize.enqueue(heapByFileNum.dequeue())
+ }
+
+ sortedScanTasks.take(sortedScanTasks.size - numSmallFiles).foreach {
+ case (task, taskSize) =>
+ addToBucket(heapByFileSize, task, taskSize)
+ }
+ } else {
+ val heapByFileSize = initializeHeap(sizeFirstOrdering)
+ sortedScanTasks.foreach {
+ case (task, taskSize) =>
+ addToBucket(heapByFileSize, task, taskSize)
+ }
+ }
+
+ partitionedTasks.zipWithIndex.map {
+ case (tasks, idx) =>
+ val newPartition = createSparkInputPartition(context, tasks.toSeq)
+ new SparkDataSourceRDDPartition(idx, Seq(newPartition))
+ }
+ }
}
+
/**
 * Immutable snapshot of the constructor arguments of a `SparkInputPartition`
 * (read reflectively in `GlutenIcebergSourceUtil.regeneratePartitions`) so that
 * equivalent input partitions can be re-created around rebalanced task groups.
 *
 * @param groupingKeyType             grouping key struct type of the scan
 * @param tableBroadcast              broadcast handle to the Iceberg table
 * @param branch                      table branch name (may be null)
 * @param expectedSchemaString        serialized expected read schema
 * @param caseSensitive               whether column resolution is case sensitive
 * @param preferredLocations          preferred executor locations for the split
 * @param cacheDeleteFilesOnExecutors whether delete files are cached on executors
 */
case class SparkPartitionContext(
    groupingKeyType: org.apache.iceberg.types.Types.StructType,
    tableBroadcast: Broadcast[Table],
    branch: String,
    expectedSchemaString: String,
    caseSensitive: Boolean,
    preferredLocations: Array[String],
    cacheDeleteFilesOnExecutors: Boolean)
diff --git a/gluten-iceberg/src/test/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtilSuite.scala b/gluten-iceberg/src/test/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtilSuite.scala
new file mode 100644
index 000000000000..037de8d14794
--- /dev/null
+++ b/gluten-iceberg/src/test/scala/org/apache/iceberg/spark/source/GlutenIcebergSourceUtilSuite.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.iceberg.spark.source
+
+import org.apache.gluten.execution.SparkDataSourceRDDPartition
+
+import org.apache.spark.Partition
+import org.apache.spark.broadcast.Broadcast
+
+import org.apache.iceberg.{FileScanTask, ScanTask, ScanTaskGroup, Table}
+import org.apache.iceberg.types.Types
+import org.mockito.Mockito.{mock, when}
+import org.scalatest.funsuite.AnyFunSuite
+
+import java.util.Collections
+
+import scala.collection.JavaConverters._
+
/**
 * Unit tests for [[GlutenIcebergSourceUtil.regeneratePartitions]]: verifies that
 * scan tasks are redistributed evenly by size, that small files are additionally
 * balanced by file count, and that degenerate inputs (zero-length files, empty
 * input) are handled.
 */
class GlutenIcebergSourceUtilSuite extends AnyFunSuite {

  /**
   * Builds a [[SparkInputPartition]] whose task group contains a single mocked
   * [[FileScanTask]] of the given byte `length`. The package-private constructor
   * is invoked reflectively; fields irrelevant to partition regeneration are left
   * null/empty.
   */
  private def makeSparkInputPartition(length: Long): SparkInputPartition = {
    val task = mock(classOf[FileScanTask])
    when(task.isFileScanTask).thenReturn(true)
    when(task.asFileScanTask()).thenReturn(task)
    when(task.length()).thenReturn(length)

    val taskGroup = new ScanTaskGroup[ScanTask] {
      override def tasks(): java.util.Collection[ScanTask] =
        Collections.singletonList(task)

      override def sizeBytes(): Long = length

      override def estimatedRowsCount(): Long = 0L

      override def filesCount(): Int = 1
    }

    val constructor = classOf[SparkInputPartition].getDeclaredConstructor(
      classOf[Types.StructType],
      classOf[ScanTaskGroup[_]],
      classOf[Broadcast[Table]],
      classOf[String],
      classOf[String],
      java.lang.Boolean.TYPE,
      classOf[Array[String]],
      java.lang.Boolean.TYPE
    )
    constructor.setAccessible(true)
    constructor.newInstance(
      Types.StructType.of(),
      taskGroup,
      null,
      null,
      null,
      Boolean.box(false),
      Array.empty[String],
      Boolean.box(false)
    )
  }

  /** Groups `inputPartitions` into (up to) `numPartitions` RDD partitions. */
  private def makePartitions(
      inputPartitions: Seq[SparkInputPartition],
      numPartitions: Int): Seq[Partition] = {
    // Ceiling division: the number of input partitions placed in each group.
    val groupSize = inputPartitions.size / numPartitions +
      (if (inputPartitions.size % numPartitions == 0) 0 else 1)
    inputPartitions.grouped(groupSize).toSeq.zipWithIndex.map {
      case (partitions, idx) => new SparkDataSourceRDDPartition(idx, partitions)
    }
  }

  /** Byte lengths of every file scan task contained in `partition`. */
  private def partitionLengths(partition: Partition): Seq[Long] = {
    partition
      .asInstanceOf[SparkDataSourceRDDPartition]
      .inputPartitions
      .flatMap(
        _.asInstanceOf[SparkInputPartition]
          .taskGroup[ScanTask]()
          .tasks()
          .asScala)
      .map(_.asFileScanTask().length())
  }

  /** Total number of scan tasks contained in `partition`. */
  private def partitionFileNums(partition: Partition): Int = {
    partition
      .asInstanceOf[SparkDataSourceRDDPartition]
      .inputPartitions
      .map(
        _.asInstanceOf[SparkInputPartition]
          .taskGroup[ScanTask]()
          .tasks()
          .size)
      .sum
  }

  test("large files are distributed evenly by size") {
    val inputPartitions = Seq(100L, 90L, 80L, 70L).map(makeSparkInputPartition)
    val initialPartitions = makePartitions(inputPartitions, 2)

    // threshold 0.0 disables the small-file pass: balance purely by bytes.
    val result = GlutenIcebergSourceUtil.regeneratePartitions(initialPartitions, 0.0)

    assert(result.size === 2)

    // 100+70 and 90+80 both sum to 170.
    val sizes = result.map(partitionLengths(_).sum)
    assert(sizes.forall(_ === 170))
  }

  test("small files are distributed evenly by number of files") {
    val inputPartitions = Seq.fill(10)(10L).map(makeSparkInputPartition)
    val initialPartitions = makePartitions(inputPartitions, 5)

    // threshold 1.0: every file falls under the small-file budget.
    val result = GlutenIcebergSourceUtil.regeneratePartitions(initialPartitions, 1.0)

    assert(result.size === 5)
    val counts = result.map(partitionLengths(_).size)
    assert(counts.forall(_ === 2))
  }

  test("small files should not be placed into one partition") {
    val inputPartitions = Seq(10L, 20L, 30L, 40L, 100L).map(makeSparkInputPartition)
    val initialPartitions = makePartitions(inputPartitions, 2)

    val result = GlutenIcebergSourceUtil.regeneratePartitions(initialPartitions, 0.5)

    assert(result.size === 2)
    // Every partition should have received at least one of the small files.
    assert(result.forall(partition => partitionLengths(partition).exists(_ <= 40)))
  }

  test("mixed small and large files should be evenly distributed") {
    val inputPartitions =
      Seq(10L, 20L, 30L, 40L, 50L, 60L, 70L, 80L, 90L, 100L).map(makeSparkInputPartition)
    val initialPartitions = makePartitions(inputPartitions, 3)

    val result = GlutenIcebergSourceUtil.regeneratePartitions(initialPartitions, 0.5)

    assert(result.size === 3)
    assert(result.forall(partition => partitionFileNums(partition) >= 3))
  }

  test("zero length files") {
    val inputPartitions = Seq(0L, 0L).map(makeSparkInputPartition)
    val initialPartitions = makePartitions(inputPartitions, 2)

    val result = GlutenIcebergSourceUtil.regeneratePartitions(initialPartitions, 0.0)

    assert(result.size === 2)
    // Zero-byte tasks must still be spread across buckets, not dropped or stacked.
    assert(result.count(partitionLengths(_).nonEmpty) === 2)
  }

  test("empty inputs") {
    val result = GlutenIcebergSourceUtil.regeneratePartitions(Seq.empty, 0.5)
    assert(result.size === 0)
  }
}
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala
index a0c3bb875753..90016f652a96 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BatchScanExecTransformer.scala
@@ -173,7 +173,7 @@ abstract class BatchScanExecTransformerBase(
override def metricsUpdater(): MetricsUpdater =
BackendsApiManager.getMetricsApiInstance.genBatchScanTransformerMetricsUpdater(metrics)
- @transient protected lazy val finalPartitions: Seq[Partition] =
+ def getFinalPartitions: Seq[Partition] =
SparkShimLoader.getSparkShims
.orderPartitions(
this,
@@ -189,6 +189,8 @@ abstract class BatchScanExecTransformerBase(
case (inputPartitions, index) => new SparkDataSourceRDDPartition(index, inputPartitions)
}
  // Default partition layout: whatever getFinalPartitions produces. Marked
  // @transient lazy so the driver-side computation is deferred and never
  // serialized; subclasses may override to post-process the partitions.
  @transient protected lazy val finalPartitions: Seq[Partition] = getFinalPartitions
+
@transient override lazy val fileFormat: ReadFileFormat =
BackendsApiManager.getSettings.getSubstraitReadFileFormatV2(scan)