Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions gluten-iceberg/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@
<artifactId>scalatest_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>2.23.4</version>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.gluten.execution

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.execution.IcebergScanTransformer.{containsMetadataColumn, containsUuidOrFixedType}
import org.apache.gluten.sql.shims.SparkShimLoader
Expand Down Expand Up @@ -195,6 +196,16 @@ case class IcebergScanTransformer(
metadataColumns ++ extraMetadataColumns
}

// Partitions ultimately handed to the RDD. When the scan is key-grouped
// (keyGroupedPartitioning is defined) the original grouping must be kept, so the
// partitions are passed through untouched; otherwise they are rebalanced to spread
// scan bytes evenly, using the configured small-file threshold.
@transient override protected lazy val finalPartitions: Seq[Partition] =
  keyGroupedPartitioning match {
    case Some(_) =>
      getFinalPartitions
    case None =>
      val threshold = GlutenConfig.get.smallFileThreshold
      GlutenIcebergSourceUtil.regeneratePartitions(getFinalPartitions, threshold)
  }

override lazy val fileFormat: ReadFileFormat = GlutenIcebergSourceUtil.getFileFormat(scan)

override def getSplitInfosFromPartitions(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,26 @@ import org.apache.gluten.execution.SparkDataSourceRDDPartition
import org.apache.gluten.substrait.rel.{IcebergLocalFilesBuilder, SplitInfo}
import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat

import org.apache.spark.Partition
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.softaffinity.SoftAffinity
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.types.StructType

import org.apache.iceberg._
import org.apache.iceberg.spark.SparkSchemaUtil
import org.apache.iceberg.util.TableScanUtil

import java.lang.{Class, Long => JLong}
import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList, Map => JMap}
import java.util.Locale

import scala.collection.JavaConverters._
import scala.collection.mutable

object GlutenIcebergSourceUtil {
object GlutenIcebergSourceUtil extends Logging {
private val InputFileNameCol = "input_file_name"
private val InputFileBlockStartCol = "input_file_block_start"
private val InputFileBlockLengthCol = "input_file_block_length"
Expand Down Expand Up @@ -234,4 +239,158 @@ object GlutenIcebergSourceUtil {
case _ =>
throw new GlutenNotSupportException("Iceberg Only support parquet and orc file format.")
}

/**
 * Regroups the scan tasks of `inputPartitions` into the same number of partitions,
 * balancing the load across them.
 *
 * Tasks are sorted by byte size (largest first) and assigned greedily to the currently
 * least-loaded bucket. When `smallFileThreshold` > 0, the smallest tasks whose
 * cumulative size fits within `totalBytes * smallFileThreshold` are first spread by
 * file count (so no single bucket collects all the tiny files); the remaining tasks
 * are then spread by total byte size.
 *
 * @param inputPartitions partitions to rebalance; each must be a SparkDataSourceRDDPartition
 * @param smallFileThreshold fraction of the total scan bytes treated as the small-file
 *                           budget; values <= 0 disable the small-file pass
 * @return rebalanced partitions (same count as the input), or empty for empty input
 * @throws GlutenNotSupportException for unsupported partition or scan-task types
 */
def regeneratePartitions(
    inputPartitions: Seq[Partition],
    smallFileThreshold: Double): Seq[Partition] = {
  if (inputPartitions.isEmpty) {
    return Seq.empty
  }

  // Only SparkDataSourceRDDPartition can be unpacked into Iceberg scan tasks.
  val icebergPartitions: Seq[SparkDataSourceRDDPartition] = inputPartitions.map {
    case partition: SparkDataSourceRDDPartition => partition
    case other =>
      throw new GlutenNotSupportException(
        s"Unsupported partition type: ${other.getClass.getSimpleName}")
  }

  // One bucket per original partition; tasks are redistributed into these buckets.
  val partitionedTasks = Array.fill(icebergPartitions.size)(mutable.ArrayBuffer.empty[ScanTask])

  // Reads the private constructor state of a SparkInputPartition via reflection so
  // that new instances can later be rebuilt with a regrouped task set.
  def getSparkInputPartitionContext(
      inputPartition: SparkInputPartition): SparkPartitionContext = {
    val clazz = classOf[SparkInputPartition]
    def readField[T](fieldName: String): T = {
      val field = clazz.getDeclaredField(fieldName)
      field.setAccessible(true)
      field.get(inputPartition).asInstanceOf[T]
    }

    SparkPartitionContext(
      groupingKeyType = readField[org.apache.iceberg.types.Types.StructType]("groupingKeyType"),
      tableBroadcast = readField[Broadcast[Table]]("tableBroadcast"),
      branch = inputPartition.branch(),
      expectedSchemaString = readField[String]("expectedSchemaString"),
      caseSensitive = inputPartition.isCaseSensitive,
      preferredLocations = inputPartition.preferredLocations(),
      cacheDeleteFilesOnExecutors = inputPartition.cacheDeleteFilesOnExecutors()
    )
  }

  // Extracts the scan tasks of one input partition; only file scan tasks and
  // CombinedScanTask are supported.
  def getScanTasks(inputPartition: SparkInputPartition): Seq[ScanTask] = {
    inputPartition.taskGroup[ScanTask]().tasks().asScala.toSeq.map {
      case task if task.isFileScanTask => task
      case task: CombinedScanTask => task
      case other =>
        throw new GlutenNotSupportException(
          s"Unsupported scan task type: ${other.getClass.getSimpleName}")
    }
  }

  // Size in bytes of one task (sum of member tasks for a CombinedScanTask).
  def getScanTaskSize(scanTask: ScanTask): Long = scanTask match {
    case task if task.isFileScanTask => task.asFileScanTask().length()
    case task: CombinedScanTask => task.tasks().asScala.map(_.length()).sum
    case other =>
      throw new GlutenNotSupportException(
        s"Unsupported scan task type: ${other.getClass.getSimpleName}")
  }

  // Assigns `scanTask` to the bucket at the top of `heap` (the least-loaded one
  // under the heap's ordering) and re-enqueues the bucket with its updated load.
  // Heap entries are (totalBytes, numFiles, bucketIndex).
  def addToBucket(
      heap: mutable.PriorityQueue[(Long, Int, Int)],
      scanTask: ScanTask,
      taskSize: Long): Unit = {
    val (size, numFiles, idx) = heap.dequeue()
    partitionedTasks(idx) += scanTask
    heap.enqueue((size + taskSize, numFiles + 1, idx))
  }

  // Builds a heap seeded with one empty entry per bucket.
  def initializeHeap(
      ordering: Ordering[(Long, Int, Int)]): mutable.PriorityQueue[(Long, Int, Int)] = {
    val heap = mutable.PriorityQueue.empty[(Long, Int, Int)](ordering)
    icebergPartitions.indices.foreach(i => heap.enqueue((0L, 0, i)))
    heap
  }

  // Rebuilds a SparkInputPartition around `tasks`, reusing the captured context.
  def createSparkInputPartition(
      context: SparkPartitionContext,
      tasks: Seq[ScanTask]): SparkInputPartition = {
    val taskGroup = new BaseScanTaskGroup[ScanTask](TableScanUtil.mergeTasks(tasks.asJava))
    new SparkInputPartition(
      context.groupingKeyType,
      taskGroup,
      context.tableBroadcast,
      context.branch,
      context.expectedSchemaString,
      context.caseSensitive,
      context.preferredLocations,
      context.cacheDeleteFilesOnExecutors
    )
  }

  val sparkInputPartitions = icebergPartitions.flatMap(_.inputPartitions).map {
    case partition: SparkInputPartition => partition
    case other =>
      throw new GlutenNotSupportException(
        s"Unsupported input partition type: ${other.getClass.getSimpleName}")
  }

  // NOTE(review): assumes every input partition shares the same table/schema context,
  // so the first partition's context is reused for all rebuilt partitions — confirm
  // this invariant holds for all callers.
  val context = getSparkInputPartitionContext(sparkInputPartitions.head)
  val scanTasks = sparkInputPartitions.flatMap(getScanTasks)
  // Pair each task with its size and sort largest-first; greedy balancing works best
  // when the big items are placed before the small ones.
  val sortedScanTasks = scanTasks
    .zip(scanTasks.map(getScanTaskSize))
    .sortBy(_._2)(Ordering.Long.reverse)

  // `.reverse` turns Scala's max-heap PriorityQueue into a min-heap, so dequeue()
  // yields the bucket with the smallest (bytes, fileCount).
  val sizeFirstOrdering = Ordering
    .by[(Long, Int, Int), (Long, Int)] { case (size, numFiles, _) => (size, numFiles) }
    .reverse

  if (smallFileThreshold > 0) {
    // Byte budget for the small-file pass: a fraction of the total scan size.
    val smallFileTotalSize = sortedScanTasks.map(_._2).sum * smallFileThreshold
    // Min-heap keyed by file count first, so tiny files are spread by count.
    val numFirstOrdering = Ordering
      .by[(Long, Int, Int), (Int, Long)] { case (size, numFiles, _) => (numFiles, size) }
      .reverse
    val heapByFileNum = initializeHeap(numFirstOrdering)

    var numSmallFiles = 0
    var smallFileSize = 0L
    // Walk the smallest tasks (reverse of the size-descending order) while the
    // cumulative size stays within the budget. reverseIterator keeps takeWhile
    // lazy, so each predicate call sees smallFileSize as updated by the foreach
    // for the previous element.
    sortedScanTasks.reverseIterator
      .takeWhile(task => task._2 + smallFileSize <= smallFileTotalSize)
      .foreach {
        case (task, taskSize) =>
          addToBucket(heapByFileNum, task, taskSize)
          numSmallFiles += 1
          smallFileSize += taskSize
      }

    // Carry the loads accumulated by the small-file pass into a heap re-keyed by size.
    val heapByFileSize = mutable.PriorityQueue.empty[(Long, Int, Int)](sizeFirstOrdering)
    while (heapByFileNum.nonEmpty) {
      heapByFileSize.enqueue(heapByFileNum.dequeue())
    }

    // Distribute the remaining (larger) tasks by byte size.
    sortedScanTasks.take(sortedScanTasks.size - numSmallFiles).foreach {
      case (task, taskSize) =>
        addToBucket(heapByFileSize, task, taskSize)
    }
  } else {
    // Small-file pass disabled: distribute everything by byte size only.
    val heapByFileSize = initializeHeap(sizeFirstOrdering)
    sortedScanTasks.foreach {
      case (task, taskSize) =>
        addToBucket(heapByFileSize, task, taskSize)
    }
  }

  // Wrap each bucket back into an RDD partition; the partition count is preserved
  // (buckets left empty still yield a partition with an empty task group).
  partitionedTasks.zipWithIndex.map {
    case (tasks, idx) =>
      val newPartition = createSparkInputPartition(context, tasks.toSeq)
      new SparkDataSourceRDDPartition(idx, Seq(newPartition))
  }
}
}

/**
 * Immutable snapshot of a SparkInputPartition's constructor state (several of these
 * values are private fields read reflectively), used to rebuild partitions around a
 * regrouped set of scan tasks while keeping the original table/schema context.
 *
 * NOTE(review): `preferredLocations` is an Array, so case-class equality compares it
 * by reference, not contents — fine as long as instances are never compared.
 */
case class SparkPartitionContext(
    groupingKeyType: org.apache.iceberg.types.Types.StructType,
    tableBroadcast: Broadcast[Table],
    branch: String,
    expectedSchemaString: String,
    caseSensitive: Boolean,
    preferredLocations: Array[String],
    cacheDeleteFilesOnExecutors: Boolean)
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.iceberg.spark.source

import org.apache.gluten.execution.SparkDataSourceRDDPartition

import org.apache.spark.Partition
import org.apache.spark.broadcast.Broadcast

import org.apache.iceberg.{FileScanTask, ScanTask, ScanTaskGroup, Table}
import org.apache.iceberg.types.Types
import org.mockito.Mockito.{mock, when}
import org.scalatest.funsuite.AnyFunSuite

import java.util.Collections

import scala.collection.JavaConverters._

/**
 * Unit tests for GlutenIcebergSourceUtil.regeneratePartitions, covering the
 * size-balancing pass, the small-file (file-count) pass, and edge cases.
 */
class GlutenIcebergSourceUtilSuite extends AnyFunSuite {

  /**
   * Builds a SparkInputPartition wrapping a single mocked FileScanTask of the given
   * byte length. The partition is constructed reflectively because the
   * SparkInputPartition constructor is not publicly accessible.
   */
  private def makeSparkInputPartition(length: Long): SparkInputPartition = {
    val task = mock(classOf[FileScanTask])
    when(task.isFileScanTask).thenReturn(true)
    when(task.asFileScanTask()).thenReturn(task)
    when(task.length()).thenReturn(length)

    val taskGroup = new ScanTaskGroup[ScanTask] {
      override def tasks(): java.util.Collection[ScanTask] =
        Collections.singletonList(task)

      override def sizeBytes(): Long = length

      override def estimatedRowsCount(): Long = 0L

      override def filesCount(): Int = 1
    }

    val constructor = classOf[SparkInputPartition].getDeclaredConstructor(
      classOf[Types.StructType],
      classOf[ScanTaskGroup[_]],
      classOf[Broadcast[Table]],
      classOf[String],
      classOf[String],
      java.lang.Boolean.TYPE,
      classOf[Array[String]],
      java.lang.Boolean.TYPE
    )
    constructor.setAccessible(true)
    constructor.newInstance(
      Types.StructType.of(),
      taskGroup,
      null,
      null,
      null,
      Boolean.box(false),
      Array.empty[String],
      Boolean.box(false)
    )
  }

  /**
   * Groups `inputPartitions` into `numPartitions` SparkDataSourceRDDPartitions of
   * roughly equal size (the last group may be smaller).
   */
  private def makePartitions(
      inputPartitions: Seq[SparkInputPartition],
      numPartitions: Int): Seq[Partition] = {
    // grouped() takes the SIZE of each group, not the number of groups, so compute
    // ceil(inputPartitions.size / numPartitions).
    val groupSize = inputPartitions.size / numPartitions +
      (if (inputPartitions.size % numPartitions == 0) 0 else 1)
    inputPartitions.grouped(groupSize).toSeq.zipWithIndex.map {
      case (partitions, idx) => new SparkDataSourceRDDPartition(idx, partitions)
    }
  }

  /** Byte lengths of every file scan task contained in `partition`. */
  private def partitionLengths(partition: Partition): Seq[Long] = {
    partition
      .asInstanceOf[SparkDataSourceRDDPartition]
      .inputPartitions
      .flatMap(
        _.asInstanceOf[SparkInputPartition]
          .taskGroup[ScanTask]()
          .tasks()
          .asScala)
      .map(_.asFileScanTask().length())
  }

  /** Total number of scan tasks across all input partitions of `partition`. */
  private def partitionFileNums(partition: Partition): Int = {
    partition
      .asInstanceOf[SparkDataSourceRDDPartition]
      .inputPartitions
      .map(
        _.asInstanceOf[SparkInputPartition]
          .taskGroup[ScanTask]()
          .tasks()
          .size)
      .sum
  }

  test("large files are distributed evenly by size") {
    val inputPartitions = Seq(100L, 90L, 80L, 70L).map(makeSparkInputPartition)
    val initialPartitions = makePartitions(inputPartitions, 2)

    // Threshold 0 disables the small-file pass: distribution is by size only.
    val result = GlutenIcebergSourceUtil.regeneratePartitions(initialPartitions, 0.0)

    assert(result.size === 2)

    val sizes = result.map(partitionLengths(_).sum)
    assert(sizes.forall(_ === 170))
  }

  test("small files are distributed evenly by number of files") {
    val inputPartitions = Seq.fill(10)(10L).map(makeSparkInputPartition)
    val initialPartitions = makePartitions(inputPartitions, 5)

    // Threshold 1.0 treats every file as small: distribution is by file count.
    val result = GlutenIcebergSourceUtil.regeneratePartitions(initialPartitions, 1.0)

    assert(result.size === 5)
    val counts = result.map(partitionLengths(_).size)
    assert(counts.forall(_ === 2))
  }

  test("small files should not be placed into one partition") {
    val inputPartitions = Seq(10L, 20L, 30L, 40L, 100L).map(makeSparkInputPartition)
    val initialPartitions = makePartitions(inputPartitions, 2)

    val result = GlutenIcebergSourceUtil.regeneratePartitions(initialPartitions, 0.5)

    assert(result.size === 2)
    // Each output partition must contain at least one of the small files.
    assert(result.forall(partition => partitionLengths(partition).exists(_ <= 40)))
  }

  test("mixed small and large files should be evenly distributed") {
    val inputPartitions =
      Seq(10L, 20L, 30L, 40L, 50L, 60L, 70L, 80L, 90L, 100L).map(makeSparkInputPartition)
    val initialPartitions = makePartitions(inputPartitions, 3)

    val result = GlutenIcebergSourceUtil.regeneratePartitions(initialPartitions, 0.5)

    assert(result.size === 3)
    assert(result.forall(partition => partitionFileNums(partition) >= 3))
  }

  test("zero length files") {
    val inputPartitions = Seq(0L, 0L).map(makeSparkInputPartition)
    val initialPartitions = makePartitions(inputPartitions, 2)

    val result = GlutenIcebergSourceUtil.regeneratePartitions(initialPartitions, 0.0)

    assert(result.size === 2)
    // Zero-byte tasks must still be spread out, not dropped or lumped together.
    assert(result.count(partitionLengths(_).nonEmpty) === 2)
  }

  test("empty inputs") {
    val result = GlutenIcebergSourceUtil.regeneratePartitions(Seq.empty, 0.5)
    assert(result.size === 0)
  }
}
Loading
Loading