Contributor

@shujingyang-db shujingyang-db commented Aug 28, 2025

What changes were proposed in this pull request?

Currently, Spark's DataFrame repartition() API only supports hash-based and range-based partitioning strategies. Users who need precise control over which partition each row goes to (similar to RDD's partitionBy with custom partitioners) have no direct way to achieve this at the DataFrame level.

This PR introduces a new DataFrame API, repartitionById(numPartitions, partitionIdExpr), which lets users directly specify the target partition ID of each row when repartitioning a DataFrame:

// Partition rows based on a computed partition ID
val df = spark.range(100).withColumn("partition_id", col("id") % 10)
val repartitioned = df.repartitionById(10, $"partition_id")
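
The intended semantics can be modelled without Spark. The following is only a minimal sketch, not the PR's implementation: `PassThroughSketch` and `partitionById` are hypothetical names, and the in-memory buckets stand in for shuffle partitions. It shows that the computed ID is used as the bucket index as-is, with no hashing or range sampling.

```scala
object PassThroughSketch {
  /** Places each row in the bucket named by its partition ID, with no hashing.
    * Hypothetical helper for illustration; buckets stand in for shuffle partitions. */
  def partitionById[T](rows: Seq[T], numPartitions: Int)(partitionId: T => Int): Vector[Vector[T]] = {
    require(numPartitions > 0, "numPartitions must be positive")
    // One independent builder per target partition.
    val buckets = Vector.fill(numPartitions)(Vector.newBuilder[T])
    rows.foreach { row =>
      val id = partitionId(row)
      // The ID is taken as-is, so it must already be a valid bucket index.
      require(id >= 0 && id < numPartitions, s"partition id $id out of range [0, $numPartitions)")
      buckets(id) += row
    }
    buckets.map(_.result())
  }
}
```

For example, `PassThroughSketch.partitionById(0 until 100, 10)(_ % 10)` puts every value whose last digit is 3 into bucket 3, mirroring the `partition_id` column above.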

Why are the changes needed?

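It gives DataFrame users the same direct control over row placement that RDD's partitionBy offers with custom partitioners, which is not possible with repartition() today.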

Does this PR introduce any user-facing change?

Yes. It adds the new DataFrame API repartitionById.

How was this patch tested?

New Unit Tests in DataFrameSuite

Was this patch authored or co-authored using generative AI tooling?

No

@github-actions github-actions bot added the SQL label Aug 28, 2025
@HyukjinKwon HyukjinKwon changed the title [DRAFT][ SPARK-53401] Enable Direct Passthrough Partitioning in the DataFrame API [DRAFT][SPARK-53401] Enable Direct Passthrough Partitioning in the DataFrame API Aug 28, 2025
@@ -2045,6 +2045,19 @@ object functions {
*/
def spark_partition_id(): Column = Column.fn("spark_partition_id")

/**
* Returns the partition ID specified by the given column expression for direct shuffle
* partitioning. The input expression must evaluate to an integral type and must not be null.
Contributor

Will this partition ID be changed by AQE?

@shujingyang-db shujingyang-db marked this pull request as ready for review August 28, 2025 07:07
@shujingyang-db shujingyang-db changed the title [DRAFT][SPARK-53401] Enable Direct Passthrough Partitioning in the DataFrame API [SPARK-53401] Enable Direct Passthrough Partitioning in the DataFrame API Aug 28, 2025
*
* This partitioning maps directly to the PartitionIdPassthrough RDD partitioner.
*/
case class ShufflePartitionIdPassThrough(

Could creating this on a column with high cardinality lead to a sudden increase in partitions? Will subsequent AQE rules try to act and reduce the number of partitions?

Member

Nope, it will not reuse or remove shuffles. This is mainly meant to replace RDD's Partitioner API so that people can migrate completely to the DataFrame API. In terms of performance and efficiency, it won't be super useful.

* @group typedrel
* @since 4.1.0
*/
def repartitionById(partitionIdExpr: Column): Dataset[T] = {
Contributor

I feel it's risky to provide a default numPartitions. Can we always ask users to specify numPartitions?

*/
def repartitionById(numPartitions: Int, partitionIdExpr: Column): Dataset[T] = {
val directShufflePartitionIdCol = Column(DirectShufflePartitionID(partitionIdExpr.expr))
repartitionByExpression(Some(numPartitions), Seq(directShufflePartitionIdCol))
Contributor

We can create RepartitionByExpression directly with a special flag to indicate pass-through; then we don't need DirectShufflePartitionID.

val e = intercept[SparkException] {
repartitioned.collect()
}
assert(e.getCause.isInstanceOf[IllegalArgumentException])
Contributor

What's the actual error? If the error message is not clear, we should do an explicit null check, or simply treat null as partition ID 0.
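
The two options in this comment can be sketched in plain Scala. This is only an illustration under stated assumptions: `NullPartitionIdSketch`, `strict`, `lenient`, and the error message are hypothetical and not taken from the PR, and a null partition ID is modelled as `None`.

```scala
object NullPartitionIdSketch {
  /** Option 1: explicit null check that fails with a clear message.
    * The message text is a made-up example, not Spark's actual error. */
  def strict(id: Option[Int]): Int =
    id.getOrElse(
      throw new IllegalArgumentException("The partition ID expression must not be null."))

  /** Option 2: silently treat a null partition ID as partition 0. */
  def lenient(id: Option[Int]): Int = id.getOrElse(0)
}
```

The strict variant surfaces bad input early with a readable message, while the lenient variant never fails but can quietly pile all null rows into partition 0.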

@@ -1406,6 +1406,87 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
assert(planned.exists(_.isInstanceOf[GlobalLimitExec]))
assert(planned.exists(_.isInstanceOf[LocalLimitExec]))
}

test("SPARK-53401: repartitionById should throw an exception for negative partition id") {
Contributor

Hmm, shall we use pmod then? Then the partition ID is always non-negative; see https://docs.databricks.com/aws/en/sql/language-manual/functions/pmod
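
To make the suggestion concrete, here is a small plain-Scala sketch of the difference between the JVM's `%` and a positive modulus. `PmodSketch.pmod` is a hypothetical helper written for illustration, not Spark's own implementation, and it assumes a positive divisor.

```scala
object PmodSketch {
  /** Positive modulus: for n > 0 the result is always in [0, n).
    * Hypothetical helper; assumes n > 0. */
  def pmod(a: Int, n: Int): Int = {
    val r = a % n // on the JVM, % keeps the sign of the dividend
    if (r < 0) r + n else r
  }

  def main(args: Array[String]): Unit = {
    println(-5 % 10)      // prints -5
    println(pmod(-5, 10)) // prints 5
    println(pmod(15, 10)) // prints 5
  }
}
```

Because the result always lies in [0, n), a partition ID computed as `pmod(id, numPartitions)` can be neither negative nor greater than or equal to `numPartitions`.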

assert(e.getMessage.contains("Index -5 out of bounds"))
}

test("SPARK-53401: repartitionById should throw an exception for partition id >= numPartitions") {
Contributor

Wait, how can this happen if we do mod/pmod?

val df = spark.range(100).select($"id" % 10 as "key", $"id" as "value")
val grouped =
df.repartitionById(10, $"key")
.filter($"value" > 50).groupBy($"key").count()
Contributor

@cloud-fan cloud-fan Aug 29, 2025

So what this test proves is that Filter can propagate its child's output partitioning, which is already proven by other tests, so we don't need to verify it again here.

checkShuffleCount(grouped, 1)
}

test("SPARK-53401: shuffle reuse after a join that preserves partitioning") {
Contributor

ditto

Contributor

I think a more interesting test is to prove that a join with id pass-through and hash partitioning will still do a shuffle on the id pass-through side.
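
One plausible reason such a shuffle is still needed is that the two schemes generally place the same key in different partitions, so the sides are not co-partitioned. The plain-Scala sketch below illustrates that idea only. It is not the requested Spark test, all names are hypothetical, and `standInHash` is a made-up stand-in rather than Spark's actual hash function.

```scala
object CoPartitioningSketch {
  val numPartitions = 10

  /** Pass-through: the key itself is used as the partition ID. */
  def passThroughPartition(key: Int): Int = key

  /** Made-up stand-in for a hash function; NOT Spark's real hash. */
  def standInHash(key: Int): Int = key * 31 + 7

  /** Hash-style placement: non-negative modulus of the hash. */
  def hashPartition(key: Int): Int = Math.floorMod(standInHash(key), numPartitions)

  /** Keys that the two schemes send to different partitions. */
  def mismatchedKeys: Seq[Int] =
    (0 until numPartitions).filter(k => passThroughPartition(k) != hashPartition(k))
}
```

With this stand-in hash, key 0 lands in partition 0 under pass-through but in partition 7 under hash placement, so matching rows from the two join sides would not meet without a further shuffle.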

@HyukjinKwon HyukjinKwon changed the title [SPARK-53401] Enable Direct Passthrough Partitioning in the DataFrame API [SPARK-53401][SQL] Enable Direct Passthrough Partitioning in the DataFrame API Aug 30, 2025

5 participants