
Commit cca1cb3
Commit message: clear
Parent: 865690d

5 files changed, 18 insertions(+), 7 deletions(-)

client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java

+2 -1

@@ -363,7 +363,8 @@ public static void addSparkListener(SparkListener listener) {
   @VisibleForTesting
   public static AtomicInteger getReducerFileGroupResponseBroadcastNum = new AtomicInteger();
 
-  protected static Map<Integer, Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]>>
+  @VisibleForTesting
+  public static Map<Integer, Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]>>
       getReducerFileGroupResponseBroadcasts = JavaUtils.newConcurrentHashMap();
 
   public static byte[] serializeGetReducerFileGroupResponse(

client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java

+2 -1

@@ -480,7 +480,8 @@ public static void addSparkListener(SparkListener listener) {
   @VisibleForTesting
   public static AtomicInteger getReducerFileGroupResponseBroadcastNum = new AtomicInteger();
 
-  protected static Map<Integer, Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]>>
+  @VisibleForTesting
+  public static Map<Integer, Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]>>
       getReducerFileGroupResponseBroadcasts = JavaUtils.newConcurrentHashMap();
 
   public static byte[] serializeGetReducerFileGroupResponse(
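
Taken together, the two SparkUtils changes and the test changes below are about one thing: the GetReducerFileGroupResponse broadcast cache is JVM-wide static state, and tests that share a JVM need a way to reset it. Widening the map from protected to public (still flagged @VisibleForTesting) is what lets suites outside the org.apache.spark.shuffle.celeborn package do that. A minimal sketch of the now-possible usage, assuming ScalaTest; the suite name is illustrative and not part of this commit:

    import org.apache.spark.shuffle.celeborn.SparkUtils
    import org.scalatest.funsuite.AnyFunSuite

    // Illustrative only: reset the JVM-wide caches SparkUtils keeps for
    // GetReducerFileGroupResponse broadcasts so one test's leftovers cannot
    // leak into the next test running in the same JVM.
    class BroadcastCacheResetExample extends AnyFunSuite {
      test("starts from an empty broadcast cache") {
        SparkUtils.getReducerFileGroupResponseBroadcasts.clear() // legal now that the field is public
        SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0) // the counter was already public
        assert(SparkUtils.getReducerFileGroupResponseBroadcasts.isEmpty)
      }
    }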

tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornHashSuite.scala

+3

@@ -68,6 +68,7 @@ class CelebornHashSuite extends AnyFunSuite
   }
 
   test("celeborn spark integration test - GetReducerFileGroupResponse broadcast") {
+    SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
     SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
     val sparkConf = new SparkConf().setAppName("celeborn-demo").setMaster("local[2]")
       .set(
@@ -101,5 +102,7 @@
     assert(SparkUtils.getReducerFileGroupResponseBroadcastNum.get() > 0)
 
     celebornSparkSession.stop()
+    SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
+    SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
   }
 }

tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornSortSuite.scala

+3

@@ -69,6 +69,7 @@ class CelebornSortSuite extends AnyFunSuite
   }
 
   test("celeborn spark integration test - GetReducerFileGroupResponse broadcast") {
+    SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
     SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
     val sparkConf = new SparkConf().setAppName("celeborn-demo").setMaster("local[2]")
       .set(s"spark.${CelebornConf.CLIENT_PUSH_SORT_RANDOMIZE_PARTITION_ENABLED.key}", "false")
@@ -104,5 +105,7 @@
     assert(SparkUtils.getReducerFileGroupResponseBroadcastNum.get() > 0)
 
     celebornSparkSession.stop()
+    SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
+    SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
   }
 }
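
Both integration suites now wrap the broadcast test in the same pair of reset calls, once before the SparkConf is built and once after the session is stopped. The same isolation could also be written once and mixed into any suite via ScalaTest's BeforeAndAfterEach; the trait below is only a sketch of that alternative (the trait name is made up, and this is not what the commit does):

    import org.apache.spark.shuffle.celeborn.SparkUtils
    import org.scalatest.BeforeAndAfterEach
    import org.scalatest.funsuite.AnyFunSuite

    // Sketch of an alternative: every test in a suite that mixes this in
    // starts and finishes with the shared SparkUtils caches in a known state.
    trait ResetsBroadcastCache extends BeforeAndAfterEach { this: AnyFunSuite =>
      private def resetBroadcastCache(): Unit = {
        SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
        SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
      }
      override def beforeEach(): Unit = {
        resetBroadcastCache()
        super.beforeEach()
      }
      override def afterEach(): Unit = {
        try super.afterEach()
        finally resetBroadcastCache()
      }
    }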

tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala

+8 -5

@@ -173,9 +173,10 @@ class SparkUtilsSuite extends AnyFunSuite
       .getOrCreate()
 
     try {
+      val shuffleId = Integer.MAX_VALUE
       val getReducerFileGroupResponse = GetReducerFileGroupResponse(
         StatusCode.SUCCESS,
-        Map(Integer.valueOf(1) -> Set(new PartitionLocation(
+        Map(Integer.valueOf(shuffleId) -> Set(new PartitionLocation(
           0,
           1,
           "localhost",
@@ -185,24 +186,26 @@
           4,
           PartitionLocation.Mode.REPLICA)).asJava).asJava,
         Array(1),
-        Set(Integer.valueOf(1)).asJava)
+        Set(Integer.valueOf(shuffleId)).asJava)
 
       val serializedBytes =
-        SparkUtils.serializeGetReducerFileGroupResponse(1, getReducerFileGroupResponse)
+        SparkUtils.serializeGetReducerFileGroupResponse(shuffleId, getReducerFileGroupResponse)
       assert(serializedBytes != null && serializedBytes.length > 0)
       val broadcast = SparkUtils.getReducerFileGroupResponseBroadcasts.values().asScala.head._1
       assert(broadcast.value == getReducerFileGroupResponse)
 
       val deserializedGetReducerFileGroupResponse =
-        SparkUtils.deserializeGetReducerFileGroupResponse(1, serializedBytes)
+        SparkUtils.deserializeGetReducerFileGroupResponse(shuffleId, serializedBytes)
       assert(deserializedGetReducerFileGroupResponse == getReducerFileGroupResponse)
 
       assert(!SparkUtils.getReducerFileGroupResponseBroadcasts.isEmpty)
-      SparkUtils.invalidateSerializedGetReducerFileGroupResponse(1)
+      SparkUtils.invalidateSerializedGetReducerFileGroupResponse(shuffleId)
       assert(SparkUtils.getReducerFileGroupResponseBroadcasts.isEmpty)
       assert(!broadcast.isValid)
     } finally {
       sparkSession.stop()
+      SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
+      SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
     }
   }
 }
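
In SparkUtilsSuite the unit test used to register its synthetic response under shuffle id 1, which risks colliding with the low shuffle ids the integration suites leave behind when everything runs in one JVM; Integer.MAX_VALUE keeps the key out of that range, and the finally block now hands the next test an empty cache and a zeroed counter. The assertion still grabs the broadcast with values().asScala.head; a key-scoped lookup would make it immune to any stale entry. The fragment below is a hypothetical variation that would slot into the existing try block in place of that line, not part of the commit:

    // Hypothetical variation: fetch the cached broadcast by its shuffle id
    // instead of taking the first value in the map, so an entry left over by
    // another suite can never be picked up by accident.
    val cached = SparkUtils.getReducerFileGroupResponseBroadcasts.get(shuffleId)
    assert(cached != null, s"no cached broadcast for shuffle $shuffleId")
    val broadcast = cached._1 // Tuple2 of (Broadcast[GetReducerFileGroupResponse], serialized bytes)
    assert(broadcast.value == getReducerFileGroupResponse)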
