ZipPartitions for arbitrary number of RDDs.
kyle-winkelman committed Jan 25, 2025
1 parent 44966c9 commit f005ef7
Showing 8 changed files with 545 additions and 159 deletions.
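For illustration, a minimal sketch of what the new N-ary variant enables on the Scala side. It assumes a live SparkContext named sc; the values and names are illustrative, not part of the commit:

val a = sc.parallelize(1 to 8, 2)
val b = sc.parallelize(11 to 18, 2)
val c = sc.parallelize(21 to 28, 2)
val d = sc.parallelize(31 to 38, 2)
val e = sc.parallelize(41 to 48, 2)

// Five inputs exceed the fixed-arity overloads (which top out at four RDDs),
// so this resolves to the new varargs method. Every input arrives as an
// untyped Iterator[_] and must be cast back by the caller.
val sums = a.zipPartitions(b, c, d, e) { (iters: Seq[Iterator[_]]) =>
  val typed = iters.map(_.asInstanceOf[Iterator[Int]])
  new Iterator[Int] {
    def hasNext: Boolean = typed.forall(_.hasNext) // stop at the shortest input
    def next(): Int = typed.map(_.next()).sum      // element-wise sum across all five
  }
}
// sums.collect() == Array(105, 110, 115, 120, 125, 130, 135, 140)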
29 changes: 29 additions & 0 deletions core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction3.java
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.java.function;

import java.io.Serializable;
import java.util.Iterator;

/**
* A function that takes three inputs and returns zero or more output records.
*/
@FunctionalInterface
public interface FlatMapFunction3<T1, T2, T3, R> extends Serializable {
Iterator<R> call(T1 t1, T2 t2, T3 t3) throws Exception;
}
29 changes: 29 additions & 0 deletions core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction4.java
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.java.function;

import java.io.Serializable;
import java.util.Iterator;

/**
* A function that takes four inputs and returns zero or more output records.
*/
@FunctionalInterface
public interface FlatMapFunction4<T1, T2, T3, T4, R> extends Serializable {
Iterator<R> call(T1 t1, T2 t2, T3 t3, T4 t4) throws Exception;
}
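Both new interfaces are @FunctionalInterface SAM types, so a lambda satisfies them directly. A hedged Scala sketch (the function and values below are illustrative, not part of the commit):

import java.util.{Arrays => JArrays, Iterator => JIterator}
import org.apache.spark.api.java.function.FlatMapFunction3

// The lambda is converted to the SAM interface; it drains all three iterators
// into one concatenated string and returns it as a one-element iterator.
val concat3: FlatMapFunction3[JIterator[String], JIterator[String], JIterator[String], String] =
  (a, b, c) => {
    val sb = new StringBuilder
    Seq(a, b, c).foreach(it => while (it.hasNext) sb.append(it.next()))
    JArrays.asList(sb.toString).iterator()
  }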
80 changes: 72 additions & 8 deletions core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -21,6 +21,7 @@ import java.{lang => jl}
import java.lang.{Iterable => JIterable}
import java.util.{Comparator, Iterator => JIterator, List => JList, Map => JMap}

import scala.annotation.varargs
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag

@@ -308,20 +309,83 @@
JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classTag))(classTag, other.classTag)
}

/**
* Zip this RDD's partitions with one other RDD and return a new RDD by applying a function to
* the zipped partitions. Assumes that both the RDDs have the *same number of partitions*, but
* does *not* require them to have the same number of elements in each partition.
*/
def zipPartitions[U, V](
other: JavaRDDLike[U, _],
f: FlatMapFunction2[JIterator[T], JIterator[U], V]): JavaRDD[V] = {
def fn: (Iterator[T], Iterator[U]) => Iterator[V] = { (x: Iterator[T], y: Iterator[U]) =>
f.call(x.asJava, y.asJava).asScala
}
JavaRDD
.fromRDD(rdd.zipPartitions(other.rdd)(fn)(other.classTag, fakeClassTag[V]))(fakeClassTag[V])
}

/**
* Zip this RDD's partitions with two more RDDs and return a new RDD by applying a function to
* the zipped partitions. Assumes that all the RDDs have the *same number of partitions*, but
* does *not* require them to have the same number of elements in each partition.
*/
@Since("4.1.0")
def zipPartitions[U1, U2, V](
other1: JavaRDDLike[U1, _],
other2: JavaRDDLike[U2, _],
f: FlatMapFunction3[JIterator[T], JIterator[U1], JIterator[U2], V]): JavaRDD[V] = {
def fn: (Iterator[T], Iterator[U1], Iterator[U2]) => Iterator[V] =
(t: Iterator[T], u1: Iterator[U1], u2: Iterator[U2]) =>
f.call(t.asJava, u1.asJava, u2.asJava).asScala
JavaRDD.fromRDD(
rdd.zipPartitions(other1.rdd, other2.rdd)(fn)(
other1.classTag,
other2.classTag,
fakeClassTag[V]))(fakeClassTag[V])
}

/**
* Zip this RDD's partitions with three more RDDs and return a new RDD by applying a function to
* the zipped partitions. Assumes that all the RDDs have the *same number of partitions*, but
* does *not* require them to have the same number of elements in each partition.
*/
@Since("4.1.0")
def zipPartitions[U1, U2, U3, V](
other1: JavaRDDLike[U1, _],
other2: JavaRDDLike[U2, _],
other3: JavaRDDLike[U3, _],
f: FlatMapFunction4[JIterator[T], JIterator[U1], JIterator[U2], JIterator[U3], V])
: JavaRDD[V] = {
def fn: (Iterator[T], Iterator[U1], Iterator[U2], Iterator[U3]) => Iterator[V] =
(t: Iterator[T], u1: Iterator[U1], u2: Iterator[U2], u3: Iterator[U3]) =>
f.call(t.asJava, u1.asJava, u2.asJava, u3.asJava).asScala
JavaRDD.fromRDD(
rdd.zipPartitions(other1.rdd, other2.rdd, other3.rdd)(fn)(
other1.classTag,
other2.classTag,
other3.classTag,
fakeClassTag[V]))(fakeClassTag[V])
}

/**
* Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by applying a
* function to the zipped partitions. Assumes that all the RDDs have the *same number of
* partitions*, but does *not* require them to have the same number of elements in each
* partition.
*
* @note
* A generic version of `zipPartitions` for an arbitrary number of RDDs. It may be type unsafe
* and other `zipPartitions` methods should be preferred.
*/
@Since("4.1.0")
@varargs
def zipPartitions[U, V](
f: FlatMapFunction[JList[JIterator[U]], V],
others: JavaRDDLike[_, _]*): JavaRDD[V] = {
def fn: Seq[Iterator[_]] => Iterator[V] =
(i: Seq[Iterator[_]]) => f.call(i.map(_.asInstanceOf[Iterator[U]].asJava).asJava).asScala
JavaRDD
.fromRDD(rdd.zipPartitions(others.map(_.rdd): _*)(fn)(fakeClassTag[V]))(fakeClassTag[V])
}
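Because the varargs overload has a single element type U for every input, zipping RDDs of different element types forces U down to Object plus positional casts — the unsafety the note above warns about. A sketch under stated assumptions (a JavaSparkContext named jsc; all names illustrative):

import java.util.{Iterator => JIterator, List => JList}
import scala.jdk.CollectionConverters._
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.function.FlatMapFunction

val strings: JavaRDD[String] = jsc.parallelize(Seq("a", "b", "c", "d").asJava, 2)
val ints: JavaRDD[Integer] = jsc.parallelize(Seq[Integer](1, 2, 3, 4).asJava, 2)

// The list holds this RDD's iterator first, then the others in call order.
val f: FlatMapFunction[JList[JIterator[Object]], String] =
  (its: JList[JIterator[Object]]) => {
    val s = its.get(0).asInstanceOf[JIterator[String]]  // positional cast
    val i = its.get(1).asInstanceOf[JIterator[Integer]] // positional cast
    val out = collection.mutable.ListBuffer.empty[String]
    while (s.hasNext && i.hasNext) out += s"${s.next()}${i.next()}"
    out.iterator.asJava
  }
val zipped: JavaRDD[String] = strings.zipPartitions(f, ints)
// zipped.collect() == List("a1", "b2", "c3", "d4")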

20 changes: 20 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1028,6 +1028,26 @@
zipPartitions(rdd2, rdd3, rdd4, preservesPartitioning = false)(f)
}

/**
* A generic version of `zipPartitions` for an arbitrary number of RDDs. It may be type unsafe
* and other `zipPartitions` methods should be preferred.
*/
@Since("4.1.0")
def zipPartitions[V: ClassTag](preservesPartitioning: Boolean, rdds: RDD[_]*)(
f: Seq[Iterator[_]] => Iterator[V]): RDD[V] = withScope {
new ZippedPartitionsRDDN(sc, sc.clean(f), this +: rdds, preservesPartitioning)
}

/**
* A generic version of `zipPartitions` for an arbitrary number of RDDs. It may be type unsafe
* and other `zipPartitions` methods should be preferred.
*/
@Since("4.1.0")
def zipPartitions[V: ClassTag](rdds: RDD[_]*)(f: Seq[Iterator[_]] => Iterator[V]): RDD[V] =
withScope {
zipPartitions(preservesPartitioning = false, rdds: _*)(f)
}
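A design note on the preservesPartitioning flag, sketched below: when true, the zipped RDD keeps the first parent's partitioner (see ZippedPartitionsBaseRDD), which is only safe when the function leaves keys and their placement untouched. The sketch assumes a live SparkContext named sc; names are illustrative:

import org.apache.spark.HashPartitioner

val left = sc.parallelize(Seq(1 -> "a", 2 -> "b", 3 -> "c", 4 -> "d"))
  .partitionBy(new HashPartitioner(2))
val right = left.mapValues(_.toUpperCase) // same partitioner, same key layout

// Keys and their placement are unchanged, so keeping the partitioner is safe
// and can spare downstream stages (e.g. a join on the same keys) a shuffle.
val merged = left.zipPartitions(preservesPartitioning = true, right) {
  (iters: Seq[Iterator[_]]) =>
    val l = iters(0).asInstanceOf[Iterator[(Int, String)]]
    val r = iters(1).asInstanceOf[Iterator[(Int, String)]]
    l.zip(r).map { case ((k, v1), (_, v2)) => (k, v1 + v2) }
}
// merged.partitioner == Some(HashPartitioner(2)) rather than None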


// Actions (launch a job to return a value to the user program)

48 changes: 37 additions & 11 deletions core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
@@ -28,7 +28,7 @@ private[spark] class ZippedPartitionsPartition(
idx: Int,
@transient private val rdds: Seq[RDD[_]],
@transient val preferredLocations: Seq[String])
extends Partition {

override val index: Int = idx
var partitionValues = rdds.map(rdd => rdd.partitions(idx))
@@ -46,7 +46,7 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag](
sc: SparkContext,
var rdds: Seq[RDD[_]],
preservesPartitioning: Boolean = false)
extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) {

override val partitioner =
if (preservesPartitioning) firstParent[Any].partitioner else None
@@ -82,7 +82,7 @@ private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag]
var rdd1: RDD[A],
var rdd2: RDD[B],
preservesPartitioning: Boolean = false)
extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) {

override def compute(s: Partition, context: TaskContext): Iterator[V] = {
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
@@ -97,19 +97,19 @@ }
}
}

private[spark] class ZippedPartitionsRDD3[A: ClassTag, B: ClassTag, C: ClassTag, V: ClassTag](
sc: SparkContext,
var f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V],
var rdd1: RDD[A],
var rdd2: RDD[B],
var rdd3: RDD[C],
preservesPartitioning: Boolean = false)
extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) {

override def compute(s: Partition, context: TaskContext): Iterator[V] = {
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
f(
rdd1.iterator(partitions(0), context),
rdd2.iterator(partitions(1), context),
rdd3.iterator(partitions(2), context))
}
@@ -123,20 +123,25 @@ private[spark] class ZippedPartitionsRDD3
}
}

private[spark] class ZippedPartitionsRDD4[
A: ClassTag,
B: ClassTag,
C: ClassTag,
D: ClassTag,
V: ClassTag](
sc: SparkContext,
var f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
var rdd1: RDD[A],
var rdd2: RDD[B],
var rdd3: RDD[C],
var rdd4: RDD[D],
preservesPartitioning: Boolean = false)
extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) {

override def compute(s: Partition, context: TaskContext): Iterator[V] = {
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
f(
rdd1.iterator(partitions(0), context),
rdd2.iterator(partitions(1), context),
rdd3.iterator(partitions(2), context),
rdd4.iterator(partitions(3), context))
@@ -151,3 +156,24 @@ private[spark] class ZippedPartitionsRDD4
f = null
}
}

private[spark] class ZippedPartitionsRDDN[V: ClassTag](
sc: SparkContext,
var f: Seq[Iterator[_]] => Iterator[V],
var rddN: Seq[RDD[_]],
preservesPartitioning: Boolean = false)
extends ZippedPartitionsBaseRDD[V](sc, rddN, preservesPartitioning) {

override def compute(s: Partition, context: TaskContext): Iterator[V] = {
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
f(rdds.zip(partitions).map { case (rdd, partition) =>
rdd.iterator(partition, context)
})
}

override def clearDependencies(): Unit = {
super.clearDependencies()
rddN = null
f = null
}
}
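One behavioral note, with a hedged sketch: ZippedPartitionsRDDN inherits the partition check from ZippedPartitionsBaseRDD, so inputs with mismatched partition counts fail fast when partitions are computed rather than at task time (assumes a live SparkContext named sc; values illustrative):

val fourParts = sc.parallelize(1 to 8, 4) // 4 partitions
val twoParts  = sc.parallelize(1 to 8, 2) // 2 partitions, same element count

// Equal element counts are not enough; partition counts must match. The call
// below is expected to throw IllegalArgumentException ("Can't zip RDDs with
// unequal numbers of partitions") as soon as the zipped RDD's partitions are
// computed:
// fourParts.zipPartitions(twoParts) { (a, b) => a ++ b }.collect()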
93 changes: 74 additions & 19 deletions core/src/test/java/test/org/apache/spark/Java8RDDAPISuite.java
@@ -274,25 +274,45 @@ public void zip() {
}

@Test
public void zipPartitions2() {
JavaRDD<String> rdd1 = sc.parallelize(Arrays.asList("a", "b", "c", "d"), 2);
JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("e", "f", "g", "h"), 2);
JavaRDD<String> zipped = rdd1.zipPartitions(
rdd2, ZipPartitionsFunction.ZIP_PARTITIONS_2_FUNCTION);
Assertions.assertEquals(Arrays.asList("abef", "cdgh"), zipped.collect());
}

@Test
public void zipPartitions3() {
JavaRDD<String> rdd1 = sc.parallelize(Arrays.asList("a", "b", "c", "d"), 2);
JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("e", "f", "g", "h"), 2);
JavaRDD<String> rdd3 = sc.parallelize(Arrays.asList("i", "j", "k", "l"), 2);
JavaRDD<String> zipped = rdd1.zipPartitions(
rdd2, rdd3, ZipPartitionsFunction.ZIP_PARTITIONS_3_FUNCTION);
Assertions.assertEquals(Arrays.asList("abefij", "cdghkl"), zipped.collect());
}

@Test
public void zipPartitions4() {
JavaRDD<String> rdd1 = sc.parallelize(Arrays.asList("a", "b", "c", "d"), 2);
JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("e", "f", "g", "h"), 2);
JavaRDD<String> rdd3 = sc.parallelize(Arrays.asList("i", "j", "k", "l"), 2);
JavaRDD<String> rdd4 = sc.parallelize(Arrays.asList("m", "n", "o", "p"), 2);
JavaRDD<String> zipped = rdd1.zipPartitions(
rdd2, rdd3, rdd4, ZipPartitionsFunction.ZIP_PARTITIONS_4_FUNCTION);
Assertions.assertEquals(Arrays.asList("abefijmn", "cdghklop"), zipped.collect());
}

@Test
public void zipPartitionsN() {
JavaRDD<String> rdd1 = sc.parallelize(Arrays.asList("a", "b", "c", "d"), 2);
JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("e", "f", "g", "h"), 2);
JavaRDD<String> rdd3 = sc.parallelize(Arrays.asList("i", "j", "k", "l"), 2);
JavaRDD<String> rdd4 = sc.parallelize(Arrays.asList("m", "n", "o", "p"), 2);
JavaRDD<String> rdd5 = sc.parallelize(Arrays.asList("q", "r", "s", "t"), 2);
JavaRDD<String> zipped = rdd1.zipPartitions(
ZipPartitionsFunction.ZIP_PARTITIONS_N_FUNCTION, rdd2, rdd3, rdd4, rdd5);
Assertions.assertEquals(Arrays.asList("abefijmnqr", "cdghklopst"), zipped.collect());
}

@Test
@@ -348,4 +368,39 @@ public void collectAsMapWithIntArrayValues() {
pairRDD.collect(); // Works fine
pairRDD.collectAsMap(); // Used to crash with ClassCastException
}

private static class ZipPartitionsFunction
implements FlatMapFunction<List<Iterator<String>>, String> {

private static final ZipPartitionsFunction ZIP_PARTITIONS_N_FUNCTION =
new ZipPartitionsFunction();

private static final FlatMapFunction2<Iterator<String>, Iterator<String>, String>
ZIP_PARTITIONS_2_FUNCTION =
(Iterator<String> i1, Iterator<String> i2) ->
ZIP_PARTITIONS_N_FUNCTION.call(Arrays.asList(i1, i2));

private static final FlatMapFunction3<
Iterator<String>, Iterator<String>, Iterator<String>, String>
ZIP_PARTITIONS_3_FUNCTION =
(Iterator<String> i1, Iterator<String> i2, Iterator<String> i3) ->
ZIP_PARTITIONS_N_FUNCTION.call(Arrays.asList(i1, i2, i3));

private static final FlatMapFunction4<
Iterator<String>, Iterator<String>, Iterator<String>, Iterator<String>, String>
ZIP_PARTITIONS_4_FUNCTION =
(Iterator<String> i1, Iterator<String> i2, Iterator<String> i3, Iterator<String> i4) ->
ZIP_PARTITIONS_N_FUNCTION.call(Arrays.asList(i1, i2, i3, i4));

@Override
public Iterator<String> call(List<Iterator<String>> iterators) {
StringBuilder stringBuilder = new StringBuilder();
for (Iterator<String> iterator : iterators) {
while (iterator.hasNext()) {
stringBuilder.append(iterator.next());
}
}
return Collections.singleton(stringBuilder.toString()).iterator();
}
}
}