[CORE] ZipPartitions for arbitrary number of RDDs #49659

Status: Closed. 1 commit.

Changes from all commits:
29 changes: 29 additions & 0 deletions core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction3.java
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.java.function;

import java.io.Serializable;
import java.util.Iterator;

/**
* A function that takes three inputs and returns zero or more output records.
*/
@FunctionalInterface
public interface FlatMapFunction3<T1, T2, T3, R> extends Serializable {
Iterator<R> call(T1 t1, T2 t2, T3 t3) throws Exception;
}
29 changes: 29 additions & 0 deletions core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction4.java
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.java.function;

import java.io.Serializable;
import java.util.Iterator;

/**
* A function that takes four inputs and returns zero or more output records.
*/
@FunctionalInterface
public interface FlatMapFunction4<T1, T2, T3, T4, R> extends Serializable {
Iterator<R> call(T1 t1, T2 t2, T3 t3, T4 t4) throws Exception;
}
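
These two interfaces follow the pattern of the existing FlatMapFunction2. From Java they are typically implemented with lambdas (see the test suite below); as a minimal sketch, the same works from Scala via SAM conversion. The name concat3 and the string data are illustrative, not part of this PR:

    import java.util.{Iterator => JIterator}
    import scala.jdk.CollectionConverters._
    import org.apache.spark.api.java.function.FlatMapFunction3

    // One output record per partition triple: the concatenation of all elements.
    val concat3: FlatMapFunction3[JIterator[String], JIterator[String], JIterator[String], String] =
      (a: JIterator[String], b: JIterator[String], c: JIterator[String]) =>
        Iterator((a.asScala ++ b.asScala ++ c.asScala).mkString).asJava
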
80 changes: 72 additions & 8 deletions core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -21,6 +21,7 @@ import java.{lang => jl}
import java.lang.{Iterable => JIterable}
import java.util.{Comparator, Iterator => JIterator, List => JList, Map => JMap}

import scala.annotation.varargs
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag

@@ -308,20 +309,83 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classTag))(classTag, other.classTag)
}

/**
* Zip this RDD's partitions with one other RDD and return a new RDD by applying a function to
* the zipped partitions. Assumes that both the RDDs have the *same number of partitions*, but
* does *not* require them to have the same number of elements in each partition.
*/
def zipPartitions[U, V](
other: JavaRDDLike[U, _],
f: FlatMapFunction2[JIterator[T], JIterator[U], V]): JavaRDD[V] = {
def fn: (Iterator[T], Iterator[U]) => Iterator[V] = { (x: Iterator[T], y: Iterator[U]) =>
f.call(x.asJava, y.asJava).asScala
}
JavaRDD
.fromRDD(rdd.zipPartitions(other.rdd)(fn)(other.classTag, fakeClassTag[V]))(fakeClassTag[V])
}

/**
* Zip this RDD's partitions with two more RDDs and return a new RDD by applying a function to
* the zipped partitions. Assumes that all the RDDs have the *same number of partitions*, but
* does *not* require them to have the same number of elements in each partition.
*/
@Since("4.1.0")
def zipPartitions[U1, U2, V](
other1: JavaRDDLike[U1, _],
other2: JavaRDDLike[U2, _],
f: FlatMapFunction3[JIterator[T], JIterator[U1], JIterator[U2], V]): JavaRDD[V] = {
def fn: (Iterator[T], Iterator[U1], Iterator[U2]) => Iterator[V] =
(t: Iterator[T], u1: Iterator[U1], u2: Iterator[U2]) =>
f.call(t.asJava, u1.asJava, u2.asJava).asScala
JavaRDD.fromRDD(
[Inline review thread]

Member: It can be easily worked around. I wouldn't add this, also considering that we're being conservative on the RDD API.

Author: Is this comment in regards to the entire PR or just the changes in JavaRDDLike? My long-term goal was to add additional cogroup methods for 3, 4, and N KeyValueGroupedDatasets. I do not need all the logic from this PR for that goal, but thought it was a good small step in that direction.

Author: Here is what my long-term goal might look like: master...kyle-winkelman:spark:everything (it might have some noise in it, but it adds additional cogroup methods and SPARK-42349). If you would prefer I attempt to go straight for the big PR that does all the changes at once, I can repurpose this PR to target those changes.

Member: Why don't we use Dataset instead? We're promoting it over the RDD API actually.

rdd.zipPartitions(other1.rdd, other2.rdd)(fn)(
other1.classTag,
other2.classTag,
fakeClassTag[V]))(fakeClassTag[V])
}
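
For context on the inline review thread above: the existing two-input cogroup on KeyValueGroupedDataset, which the reviewer is promoting and which the author wants 3-, 4-, and N-ary analogues of, looks roughly like this. A sketch only, assuming a live SparkSession named spark:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().getOrCreate()
    import spark.implicits._

    val left = Seq((1, "a"), (2, "b")).toDS().groupByKey(_._1)
    val right = Seq((1, "x"), (3, "y")).toDS().groupByKey(_._1)

    // For each key, see both sides' values at once, much as zipPartitions
    // sees all partitions at once.
    val merged = left.cogroup(right) { (key, ls, rs) =>
      Iterator(s"$key -> left=${ls.size} right=${rs.size}")
    }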

/**
* Zip this RDD's partitions with three more RDDs and return a new RDD by applying a function to
* the zipped partitions. Assumes that all the RDDs have the *same number of partitions*, but
* does *not* require them to have the same number of elements in each partition.
*/
@Since("4.1.0")
def zipPartitions[U1, U2, U3, V](
other1: JavaRDDLike[U1, _],
other2: JavaRDDLike[U2, _],
other3: JavaRDDLike[U3, _],
f: FlatMapFunction4[JIterator[T], JIterator[U1], JIterator[U2], JIterator[U3], V])
: JavaRDD[V] = {
def fn: (Iterator[T], Iterator[U1], Iterator[U2], Iterator[U3]) => Iterator[V] =
(t: Iterator[T], u1: Iterator[U1], u2: Iterator[U2], u3: Iterator[U3]) =>
f.call(t.asJava, u1.asJava, u2.asJava, u3.asJava).asScala
JavaRDD.fromRDD(
rdd.zipPartitions(other1.rdd, other2.rdd, other3.rdd)(fn)(
other1.classTag,
other2.classTag,
other3.classTag,
fakeClassTag[V]))(fakeClassTag[V])
}

/**
* Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by applying a
* function to the zipped partitions. Assumes that all the RDDs have the *same number of
* partitions*, but does *not* require them to have the same number of elements in each
* partition.
*
* @note
* A generic version of `zipPartitions` for an arbitrary number of RDDs. It may be type unsafe
* and other `zipPartitions` methods should be preferred.
*/
@Since("4.1.0")
@varargs
def zipPartitions[U, V](
f: FlatMapFunction[JList[JIterator[U]], V],
others: JavaRDDLike[_, _]*): JavaRDD[V] = {
def fn: Seq[Iterator[_]] => Iterator[V] =
(i: Seq[Iterator[_]]) => f.call(i.map(_.asInstanceOf[Iterator[U]].asJava).asJava).asScala
JavaRDD
.fromRDD(rdd.zipPartitions(others.map(_.rdd): _*)(fn)(fakeClassTag[V]))(fakeClassTag[V])
}

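The java.util.List of iterators handed to f is produced by casting each Iterator[_] to Iterator[U], which is exactly the type-unsafety the @note on the varargs overload warns about: nothing forces the RDDs to share an element type U. A minimal call-site sketch, assuming a live SparkContext named sc (all names are illustrative):

    import java.util.{Iterator => JIterator, List => JList}
    import scala.jdk.CollectionConverters._
    import org.apache.spark.api.java.JavaRDD
    import org.apache.spark.api.java.function.FlatMapFunction

    val r1 = JavaRDD.fromRDD(sc.parallelize(Seq("a", "b"), 1))
    val r2 = JavaRDD.fromRDD(sc.parallelize(Seq("c", "d"), 1))
    val r3 = JavaRDD.fromRDD(sc.parallelize(Seq("e"), 1))

    // Concatenate each partition's elements across all three RDDs.
    val concat: FlatMapFunction[JList[JIterator[String]], String] =
      (its: JList[JIterator[String]]) =>
        Iterator(its.asScala.flatMap(_.asScala).mkString).asJava

    val zipped = r1.zipPartitions(concat, r2, r3) // single partition: "abcde"
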
20 changes: 20 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1028,6 +1028,26 @@ abstract class RDD[T: ClassTag](
zipPartitions(rdd2, rdd3, rdd4, preservesPartitioning = false)(f)
}

/**
* A generic version of `zipPartitions` for an arbitrary number of RDDs. It may be type unsafe
* and other `zipPartitions` methods should be preferred.
*/
@Since("4.1.0")
def zipPartitions[V: ClassTag](preservesPartitioning: Boolean, rdds: RDD[_]*)(
f: Seq[Iterator[_]] => Iterator[V]): RDD[V] = withScope {
new ZippedPartitionsRDDN(sc, sc.clean(f), this +: rdds, preservesPartitioning)
}

/**
* A generic version of `zipPartitions` for an arbitrary number of RDDs. It may be type unsafe
* and other `zipPartitions` methods should be preferred.
*/
@Since("4.1.0")
def zipPartitions[V: ClassTag](rdds: RDD[_]*)(f: Seq[Iterator[_]] => Iterator[V]): RDD[V] =
withScope {
zipPartitions(preservesPartitioning = false, rdds: _*)(f)
}


// Actions (launch a job to return a value to the user program)

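A minimal usage sketch of the new N-ary zipPartitions overloads above, assuming a live SparkContext named sc (the data is illustrative). Note that with three or fewer other RDDs a plain call still matches the existing fixed-arity overloads, so passing preservesPartitioning explicitly is the simplest way to select the varargs signature:

    import org.apache.spark.SparkContext

    val sc: SparkContext = SparkContext.getOrCreate()

    val a = sc.parallelize(1 to 8, 2)
    val b = sc.parallelize(1 to 6, 2)
    val c = sc.parallelize(1 to 4, 2)

    // Per partition, emit one record: the total element count across all three RDDs.
    val counts = a.zipPartitions(preservesPartitioning = false, b, c) { its =>
      Iterator(its.map(_.size).sum)
    }
    counts.collect() // Array(9, 9) with an even range split: 4 + 3 + 2 per partition
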
48 changes: 37 additions & 11 deletions core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
@@ -28,7 +28,7 @@ private[spark] class ZippedPartitionsPartition(
idx: Int,
@transient private val rdds: Seq[RDD[_]],
@transient val preferredLocations: Seq[String])
extends Partition {

override val index: Int = idx
var partitionValues = rdds.map(rdd => rdd.partitions(idx))
@@ -46,7 +46,7 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag](
sc: SparkContext,
var rdds: Seq[RDD[_]],
preservesPartitioning: Boolean = false)
extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) {

override val partitioner =
if (preservesPartitioning) firstParent[Any].partitioner else None
@@ -82,7 +82,7 @@ private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag]
var rdd1: RDD[A],
var rdd2: RDD[B],
preservesPartitioning: Boolean = false)
extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) {

override def compute(s: Partition, context: TaskContext): Iterator[V] = {
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
@@ -97,19 +97,19 @@
}
}

private[spark] class ZippedPartitionsRDD3[A: ClassTag, B: ClassTag, C: ClassTag, V: ClassTag](
sc: SparkContext,
var f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V],
var rdd1: RDD[A],
var rdd2: RDD[B],
var rdd3: RDD[C],
preservesPartitioning: Boolean = false)
extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) {

override def compute(s: Partition, context: TaskContext): Iterator[V] = {
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
f(
rdd1.iterator(partitions(0), context),
rdd2.iterator(partitions(1), context),
rdd3.iterator(partitions(2), context))
}
@@ -123,20 +123,25 @@ private[spark] class ZippedPartitionsRDD3
}
}

private[spark] class ZippedPartitionsRDD4[
A: ClassTag,
B: ClassTag,
C: ClassTag,
D: ClassTag,
V: ClassTag](
sc: SparkContext,
var f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
var rdd1: RDD[A],
var rdd2: RDD[B],
var rdd3: RDD[C],
var rdd4: RDD[D],
preservesPartitioning: Boolean = false)
extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) {

override def compute(s: Partition, context: TaskContext): Iterator[V] = {
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
f(
rdd1.iterator(partitions(0), context),
rdd2.iterator(partitions(1), context),
rdd3.iterator(partitions(2), context),
rdd4.iterator(partitions(3), context))
@@ -151,3 +156,24 @@ private[spark] class ZippedPartitionsRDD4
f = null
}
}

private[spark] class ZippedPartitionsRDDN[V: ClassTag](
sc: SparkContext,
var f: Seq[Iterator[_]] => Iterator[V],
var rddN: Seq[RDD[_]],
preservesPartitioning: Boolean = false)
extends ZippedPartitionsBaseRDD[V](sc, rddN, preservesPartitioning) {

override def compute(s: Partition, context: TaskContext): Iterator[V] = {
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
f(rdds.zip(partitions).map { case (rdd, partition) =>
rdd.iterator(partition, context)
})
}

override def clearDependencies(): Unit = {
super.clearDependencies()
rddN = null
f = null
}
}
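
One behavior ZippedPartitionsRDDN inherits from ZippedPartitionsBaseRDD: all zipped RDDs must have the same number of partitions, and this is enforced as soon as the zipped RDD's partitions are computed. A quick sketch, assuming a live SparkContext named sc:

    val x = sc.parallelize(1 to 4, 2)
    val y = sc.parallelize(1 to 4, 4)

    // Fails with IllegalArgumentException ("Can't zip RDDs with unequal
    // numbers of partitions") before any tasks run.
    x.zipPartitions(y)((a, b) => a ++ b).collect()
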
93 changes: 74 additions & 19 deletions core/src/test/java/test/org/apache/spark/Java8RDDAPISuite.java
@@ -274,25 +274,45 @@ public void zip() {
}

@Test
public void zipPartitions2() {
JavaRDD<String> rdd1 = sc.parallelize(Arrays.asList("a", "b", "c", "d"), 2);
JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("e", "f", "g", "h"), 2);
JavaRDD<String> zipped = rdd1.zipPartitions(
rdd2, ZipPartitionsFunction.ZIP_PARTITIONS_2_FUNCTION);
Assertions.assertEquals(Arrays.asList("abef", "cdgh"), zipped.collect());
}

@Test
public void zipPartitions3() {
JavaRDD<String> rdd1 = sc.parallelize(Arrays.asList("a", "b", "c", "d"), 2);
JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("e", "f", "g", "h"), 2);
JavaRDD<String> rdd3 = sc.parallelize(Arrays.asList("i", "j", "k", "l"), 2);
JavaRDD<String> zipped = rdd1.zipPartitions(
rdd2, rdd3, ZipPartitionsFunction.ZIP_PARTITIONS_3_FUNCTION);
Assertions.assertEquals(Arrays.asList("abefij", "cdghkl"), zipped.collect());
}

@Test
public void zipPartitions4() {
JavaRDD<String> rdd1 = sc.parallelize(Arrays.asList("a", "b", "c", "d"), 2);
JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("e", "f", "g", "h"), 2);
JavaRDD<String> rdd3 = sc.parallelize(Arrays.asList("i", "j", "k", "l"), 2);
JavaRDD<String> rdd4 = sc.parallelize(Arrays.asList("m", "n", "o", "p"), 2);
JavaRDD<String> zipped = rdd1.zipPartitions(
rdd2, rdd3, rdd4, ZipPartitionsFunction.ZIP_PARTITIONS_4_FUNCTION);
Assertions.assertEquals(Arrays.asList("abefijmn", "cdghklop"), zipped.collect());
}

@Test
public void zipPartitionsN() {
JavaRDD<String> rdd1 = sc.parallelize(Arrays.asList("a", "b", "c", "d"), 2);
JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("e", "f", "g", "h"), 2);
JavaRDD<String> rdd3 = sc.parallelize(Arrays.asList("i", "j", "k", "l"), 2);
JavaRDD<String> rdd4 = sc.parallelize(Arrays.asList("m", "n", "o", "p"), 2);
JavaRDD<String> rdd5 = sc.parallelize(Arrays.asList("q", "r", "s", "t"), 2);
JavaRDD<String> zipped = rdd1.zipPartitions(
ZipPartitionsFunction.ZIP_PARTITIONS_N_FUNCTION, rdd2, rdd3, rdd4, rdd5);
Assertions.assertEquals(Arrays.asList("abefijmnqr", "cdghklopst"), zipped.collect());
}

@Test
@@ -348,4 +368,39 @@ public void collectAsMapWithIntArrayValues() {
pairRDD.collect(); // Works fine
pairRDD.collectAsMap(); // Used to crash with ClassCastException
}

private static class ZipPartitionsFunction
implements FlatMapFunction<List<Iterator<String>>, String> {

private static final ZipPartitionsFunction ZIP_PARTITIONS_N_FUNCTION =
new ZipPartitionsFunction();

private static final FlatMapFunction2<Iterator<String>, Iterator<String>, String>
ZIP_PARTITIONS_2_FUNCTION =
(Iterator<String> i1, Iterator<String> i2) ->
ZIP_PARTITIONS_N_FUNCTION.call(Arrays.asList(i1, i2));

private static final FlatMapFunction3<
Iterator<String>, Iterator<String>, Iterator<String>, String>
ZIP_PARTITIONS_3_FUNCTION =
(Iterator<String> i1, Iterator<String> i2, Iterator<String> i3) ->
ZIP_PARTITIONS_N_FUNCTION.call(Arrays.asList(i1, i2, i3));

private static final FlatMapFunction4<
Iterator<String>, Iterator<String>, Iterator<String>, Iterator<String>, String>
ZIP_PARTITIONS_4_FUNCTION =
(Iterator<String> i1, Iterator<String> i2, Iterator<String> i3, Iterator<String> i4) ->
ZIP_PARTITIONS_N_FUNCTION.call(Arrays.asList(i1, i2, i3, i4));

@Override
public Iterator<String> call(List<Iterator<String>> iterators) {
StringBuilder stringBuilder = new StringBuilder();
for (Iterator<String> iterator : iterators) {
while (iterator.hasNext()) {
stringBuilder.append(iterator.next());
}
}
return Collections.singleton(stringBuilder.toString()).iterator();
}
}
}