ZipPartitions for arbitrary number of RDDs.
kyle-winkelman committed Jan 24, 2025
1 parent 44966c9 commit 72e8f05
Showing 8 changed files with 547 additions and 159 deletions.
29 changes: 29 additions & 0 deletions core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction3.java
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.java.function;

import java.io.Serializable;
import java.util.Iterator;

/**
* A function that takes three inputs and returns zero or more output records.
*/
@FunctionalInterface
public interface FlatMapFunction3<T1, T2, T3, R> extends Serializable {
Iterator<R> call(T1 t1, T2 t2, T3 t3) throws Exception;
}
29 changes: 29 additions & 0 deletions core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction4.java
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.java.function;

import java.io.Serializable;
import java.util.Iterator;

/**
* A function that takes four inputs and returns zero or more output records.
*/
@FunctionalInterface
public interface FlatMapFunction4<T1, T2, T3, T4, R> extends Serializable {
Iterator<R> call(T1 t1, T2 t2, T3 t3, T4 t4) throws Exception;
}
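The two new functional interfaces differ only in arity. As a minimal, hedged sketch (not part of the commit; the names and element types are made up), a FlatMapFunction3 can be written in Scala as an anonymous class; FlatMapFunction4 is the four-argument analogue:

```scala
import java.util.{Iterator => JIterator}
import scala.jdk.CollectionConverters._

import org.apache.spark.api.java.function.FlatMapFunction3

// Concatenates the three input iterators; any "zero or more" output is allowed.
val concat3 = new FlatMapFunction3[JIterator[Int], JIterator[Int], JIterator[Int], Int] {
  override def call(a: JIterator[Int], b: JIterator[Int], c: JIterator[Int]): JIterator[Int] =
    (a.asScala ++ b.asScala ++ c.asScala).asJava
}
```

Because both interfaces are annotated @FunctionalInterface, Java callers can also supply them as lambdas.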
81 changes: 73 additions & 8 deletions core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -21,6 +21,7 @@ import java.{lang => jl}
import java.lang.{Iterable => JIterable}
import java.util.{Comparator, Iterator => JIterator, List => JList, Map => JMap}

import scala.annotation.varargs
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag

@@ -38,6 +39,7 @@ import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.Utils


/**
* As a workaround for https://issues.scala-lang.org/browse/SI-8905, implementations
* of JavaRDDLike should extend this dummy abstract class instead of directly inheriting
@@ -308,20 +310,83 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classTag))(classTag, other.classTag)
}

/**
* Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
* applying a function to the zipped partitions. Assumes that all the RDDs have the
* *same number of partitions*, but does *not* require them to have the same number
* of elements in each partition.
/*
* Zip this RDD's partitions with one other RDD and return a new RDD by applying a function to
* the zipped partitions. Assumes that both the RDDs have the *same number of partitions*, but
* does *not* require them to have the same number of elements in each partition.
*/
def zipPartitions[U, V](
other: JavaRDDLike[U, _],
f: FlatMapFunction2[JIterator[T], JIterator[U], V]): JavaRDD[V] = {
def fn: (Iterator[T], Iterator[U]) => Iterator[V] = {
(x: Iterator[T], y: Iterator[U]) => f.call(x.asJava, y.asJava).asScala
def fn: (Iterator[T], Iterator[U]) => Iterator[V] = { (x: Iterator[T], y: Iterator[U]) =>
f.call(x.asJava, y.asJava).asScala
}
JavaRDD
.fromRDD(rdd.zipPartitions(other.rdd)(fn)(other.classTag, fakeClassTag[V]))(fakeClassTag[V])
}

/**
* Zip this RDD's partitions with two more RDDs and return a new RDD by applying a function to
* the zipped partitions. Assumes that all the RDDs have the *same number of partitions*, but
* does *not* require them to have the same number of elements in each partition.
*/
@Since("4.1.0")
def zipPartitions[U1, U2, V](
other1: JavaRDDLike[U1, _],
other2: JavaRDDLike[U2, _],
f: FlatMapFunction3[JIterator[T], JIterator[U1], JIterator[U2], V]): JavaRDD[V] = {
def fn: (Iterator[T], Iterator[U1], Iterator[U2]) => Iterator[V] =
(t: Iterator[T], u1: Iterator[U1], u2: Iterator[U2]) =>
f.call(t.asJava, u1.asJava, u2.asJava).asScala
JavaRDD.fromRDD(
rdd.zipPartitions(other1.rdd, other2.rdd)(fn)(
other1.classTag,
other2.classTag,
fakeClassTag[V]))(fakeClassTag[V])
}

/**
* Zip this RDD's partitions with three more RDDs and return a new RDD by applying a function to
* the zipped partitions. Assumes that all the RDDs have the *same number of partitions*, but
* does *not* require them to have the same number of elements in each partition.
*/
@Since("4.1.0")
def zipPartitions[U1, U2, U3, V](
other1: JavaRDDLike[U1, _],
other2: JavaRDDLike[U2, _],
other3: JavaRDDLike[U3, _],
f: FlatMapFunction4[JIterator[T], JIterator[U1], JIterator[U2], JIterator[U3], V])
: JavaRDD[V] = {
def fn: (Iterator[T], Iterator[U1], Iterator[U2], Iterator[U3]) => Iterator[V] =
(t: Iterator[T], u1: Iterator[U1], u2: Iterator[U2], u3: Iterator[U3]) =>
f.call(t.asJava, u1.asJava, u2.asJava, u3.asJava).asScala
JavaRDD.fromRDD(
rdd.zipPartitions(other.rdd)(fn)(other.classTag, fakeClassTag[V]))(fakeClassTag[V])
rdd.zipPartitions(other1.rdd, other2.rdd, other3.rdd)(fn)(
other1.classTag,
other2.classTag,
other3.classTag,
fakeClassTag[V]))(fakeClassTag[V])
}

/**
* Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by applying a
* function to the zipped partitions. Assumes that all the RDDs have the *same number of
* partitions*, but does *not* require them to have the same number of elements in each
* partition.
*
* @note
* A generic version of `zipPartitions` for an arbitrary number of RDDs. It may be type unsafe
* and other `zipPartitions` methods should be preferred.
*/
@Since("4.1.0")
@varargs
def zipPartitions[U, V](
f: FlatMapFunction[JList[JIterator[U]], V],
others: JavaRDDLike[_, _]*): JavaRDD[V] = {
def fn: Seq[Iterator[_]] => Iterator[V] =
(i: Seq[Iterator[_]]) => f.call(i.map(_.asInstanceOf[Iterator[U]].asJava).asJava).asScala
JavaRDD
.fromRDD(rdd.zipPartitions(others.map(_.rdd): _*)(fn)(fakeClassTag[V]))(fakeClassTag[V])
}
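The varargs overload presents every zipped partition as a `JIterator[U]` inside a single `JList`, so it is only type-safe when all RDDs share an element type (or when the caller casts). A minimal sketch of calling it from Scala, not part of the commit; `sc` is an assumed, already-created SparkContext and the data is made up:

```scala
import java.util.{Iterator => JIterator, List => JList}
import scala.jdk.CollectionConverters._

import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.function.FlatMapFunction

// Four JavaRDDs with the same element type and the same number of partitions.
val first = sc.parallelize(1 to 8, 4).toJavaRDD()
val others = Seq.fill(3)(sc.parallelize(1 to 8, 4).toJavaRDD())

// Emit one record per partition: the sum over the zipped iterators of all four RDDs.
val perPartitionSums: JavaRDD[Int] = first.zipPartitions(
  new FlatMapFunction[JList[JIterator[Int]], Int] {
    override def call(its: JList[JIterator[Int]]): JIterator[Int] =
      Iterator.single(its.asScala.map(_.asScala.sum).sum).asJava
  },
  others: _*)
```

Thanks to the `@varargs` annotation, Java code can pass the other RDDs as plain varargs after the function argument.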

/**
20 changes: 20 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1028,6 +1028,26 @@ abstract class RDD[T: ClassTag](
zipPartitions(rdd2, rdd3, rdd4, preservesPartitioning = false)(f)
}

/**
* A generic version of `zipPartitions` for an arbitrary number of RDDs. It may be type unsafe
* and other `zipPartitions` methods should be preferred.
*/
@Since("4.1.0")
def zipPartitions[V: ClassTag](preservesPartitioning: Boolean, rdds: RDD[_]*)(
f: Seq[Iterator[_]] => Iterator[V]): RDD[V] = withScope {
new ZippedPartitionsRDDN(sc, sc.clean(f), this +: rdds, preservesPartitioning)
}

/**
* A generic version of `zipPartitions` for an arbitrary number of RDDs. It may be type unsafe
* and other `zipPartitions` methods should be preferred.
*/
@Since("4.1.0")
def zipPartitions[V: ClassTag](rdds: RDD[_]*)(f: Seq[Iterator[_]] => Iterator[V]): RDD[V] =
withScope {
zipPartitions(preservesPartitioning = false, rdds: _*)(f)
}
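On the Scala side, the N-ary overload erases the element types: the closure receives a `Seq[Iterator[_]]` (this RDD's iterator first, then the others in order) and must cast each iterator back to its known type. A minimal sketch, not part of the commit; `sc` is an assumed SparkContext and the data is made up:

```scala
// Five RDDs, one more than the fixed-arity overloads can express.
val r1 = sc.parallelize( 1 to  4, 2)
val r2 = sc.parallelize(11 to 14, 2)
val r3 = sc.parallelize(21 to 24, 2)
val r4 = sc.parallelize(31 to 34, 2)
val r5 = sc.parallelize(41 to 44, 2)

// Sum the k-th element of every RDD's partition; the caller casts the untyped iterators.
val columnSums = r1.zipPartitions(r2, r3, r4, r5) { iters =>
  val typed = iters.map(_.asInstanceOf[Iterator[Int]].toSeq)
  typed.transpose.map(_.sum).iterator
}
```

The `preservesPartitioning` variant behaves the same way but keeps the first RDD's partitioner, mirroring the existing fixed-arity overloads.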


// Actions (launch a job to return a value to the user program)

48 changes: 37 additions & 11 deletions core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
@@ -28,7 +28,7 @@ private[spark] class ZippedPartitionsPartition(
idx: Int,
@transient private val rdds: Seq[RDD[_]],
@transient val preferredLocations: Seq[String])
extends Partition {
extends Partition {

override val index: Int = idx
var partitionValues = rdds.map(rdd => rdd.partitions(idx))
@@ -46,7 +46,7 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag](
sc: SparkContext,
var rdds: Seq[RDD[_]],
preservesPartitioning: Boolean = false)
extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) {
extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) {

override val partitioner =
if (preservesPartitioning) firstParent[Any].partitioner else None
@@ -82,7 +82,7 @@ private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag]
var rdd1: RDD[A],
var rdd2: RDD[B],
preservesPartitioning: Boolean = false)
extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) {
extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) {

override def compute(s: Partition, context: TaskContext): Iterator[V] = {
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
@@ -97,19 +97,19 @@ private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag]
}
}

private[spark] class ZippedPartitionsRDD3
[A: ClassTag, B: ClassTag, C: ClassTag, V: ClassTag](
private[spark] class ZippedPartitionsRDD3[A: ClassTag, B: ClassTag, C: ClassTag, V: ClassTag](
sc: SparkContext,
var f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V],
var rdd1: RDD[A],
var rdd2: RDD[B],
var rdd3: RDD[C],
preservesPartitioning: Boolean = false)
extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) {
extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) {

override def compute(s: Partition, context: TaskContext): Iterator[V] = {
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
f(rdd1.iterator(partitions(0), context),
f(
rdd1.iterator(partitions(0), context),
rdd2.iterator(partitions(1), context),
rdd3.iterator(partitions(2), context))
}
@@ -123,20 +123,25 @@ private[spark] class ZippedPartitionsRDD3
}
}

private[spark] class ZippedPartitionsRDD4
[A: ClassTag, B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag](
private[spark] class ZippedPartitionsRDD4[
A: ClassTag,
B: ClassTag,
C: ClassTag,
D: ClassTag,
V: ClassTag](
sc: SparkContext,
var f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
var rdd1: RDD[A],
var rdd2: RDD[B],
var rdd3: RDD[C],
var rdd4: RDD[D],
preservesPartitioning: Boolean = false)
extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) {
extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) {

override def compute(s: Partition, context: TaskContext): Iterator[V] = {
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
f(rdd1.iterator(partitions(0), context),
f(
rdd1.iterator(partitions(0), context),
rdd2.iterator(partitions(1), context),
rdd3.iterator(partitions(2), context),
rdd4.iterator(partitions(3), context))
@@ -151,3 +156,24 @@ private[spark] class ZippedPartitionsRDD4
f = null
}
}

private[spark] class ZippedPartitionsRDDN[V: ClassTag](
sc: SparkContext,
var f: Seq[Iterator[_]] => Iterator[V],
var rddN: Seq[RDD[_]],
preservesPartitioning: Boolean = false)
extends ZippedPartitionsBaseRDD[V](sc, rddN, preservesPartitioning) {

override def compute(s: Partition, context: TaskContext): Iterator[V] = {
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
f(rdds.zip(partitions).map { case (rdd, partition) =>
rdd.iterator(partition, context)
})
}

override def clearDependencies(): Unit = {
super.clearDependencies()
rddN = null
f = null
}
}