Skip to content

Commit 17dade9

Browse files
committed
keylock
1 parent 9482635 commit 17dade9

File tree

8 files changed

+95
-12
lines changed

8 files changed

+95
-12
lines changed

client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java

+8-3
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
import scala.Tuple2;
3737

3838
import com.google.common.annotations.VisibleForTesting;
39-
import org.apache.commons.io.output.ByteArrayOutputStream;
4039
import org.apache.spark.BarrierTaskContext;
4140
import org.apache.spark.SparkConf;
4241
import org.apache.spark.SparkContext;
@@ -55,11 +54,11 @@
5554
import org.apache.spark.sql.execution.UnsafeRowSerializer;
5655
import org.apache.spark.sql.execution.metric.SQLMetric;
5756
import org.apache.spark.storage.BlockManagerId;
58-
import org.apache.spark.util.KeyLock;
5957
import org.slf4j.Logger;
6058
import org.slf4j.LoggerFactory;
6159

6260
import org.apache.celeborn.client.ShuffleClient;
61+
import org.apache.celeborn.client.util.KeyLock;
6362
import org.apache.celeborn.common.CelebornConf;
6463
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse;
6564
import org.apache.celeborn.common.util.JavaUtils;
@@ -390,7 +389,13 @@ public static byte[] serializeGetReducerFileGroupResponse(
390389
GetReducerFileGroupResponse.class));
391390

392391
CompressionCodec codec = CompressionCodec.createCodec(sparkContext.conf());
393-
ByteArrayOutputStream out = new ByteArrayOutputStream();
392+
// Using `org.apache.commons.io.output.ByteArrayOutputStream` instead of the standard
393+
// one
394+
// This implementation doesn't reallocate the whole memory block but allocates
395+
// additional buffers. This way no buffers need to be garbage collected and
396+
// the contents don't have to be copied to the new buffer.
397+
org.apache.commons.io.output.ByteArrayOutputStream out =
398+
new org.apache.commons.io.output.ByteArrayOutputStream();
394399
try (ObjectOutputStream oos =
395400
new ObjectOutputStream(codec.compressedOutputStream(out))) {
396401
oos.writeObject(broadcast);

client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java

+8-3
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
import scala.Tuple2;
3434

3535
import com.google.common.annotations.VisibleForTesting;
36-
import org.apache.commons.io.output.ByteArrayOutputStream;
3736
import org.apache.spark.BarrierTaskContext;
3837
import org.apache.spark.MapOutputTrackerMaster;
3938
import org.apache.spark.SparkConf;
@@ -58,11 +57,11 @@
5857
import org.apache.spark.sql.execution.UnsafeRowSerializer;
5958
import org.apache.spark.sql.execution.metric.SQLMetric;
6059
import org.apache.spark.storage.BlockManagerId;
61-
import org.apache.spark.util.KeyLock;
6260
import org.slf4j.Logger;
6361
import org.slf4j.LoggerFactory;
6462

6563
import org.apache.celeborn.client.ShuffleClient;
64+
import org.apache.celeborn.client.util.KeyLock;
6665
import org.apache.celeborn.common.CelebornConf;
6766
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse;
6867
import org.apache.celeborn.common.util.JavaUtils;
@@ -506,7 +505,13 @@ public static byte[] serializeGetReducerFileGroupResponse(
506505
GetReducerFileGroupResponse.class));
507506

508507
CompressionCodec codec = CompressionCodec.createCodec(sparkContext.conf());
509-
ByteArrayOutputStream out = new ByteArrayOutputStream();
508+
// Using `org.apache.commons.io.output.ByteArrayOutputStream` instead of the standard
509+
// one
510+
// This implementation doesn't reallocate the whole memory block but allocates
511+
// additional buffers. This way no buffers need to be garbage collected and
512+
// the contents don't have to be copied to the new buffer.
513+
org.apache.commons.io.output.ByteArrayOutputStream out =
514+
new org.apache.commons.io.output.ByteArrayOutputStream();
510515
try (ObjectOutputStream oos =
511516
new ObjectOutputStream(codec.compressedOutputStream(out))) {
512517
oos.writeObject(broadcast);

client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@ import org.apache.spark.shuffle.celeborn.CelebornShuffleReader.streamCreatorPool
3535
import org.apache.spark.util.CompletionIterator
3636
import org.apache.spark.util.collection.ExternalSorter
3737

38-
import org.apache.celeborn.client.{ClientUtils, ShuffleClient}
38+
import org.apache.celeborn.client.ShuffleClient
3939
import org.apache.celeborn.client.ShuffleClientImpl.ReduceFileGroups
4040
import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback}
41+
import org.apache.celeborn.client.util.ClientUtils
4142
import org.apache.celeborn.common.CelebornConf
4243
import org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRetryAbleException}
4344
import org.apache.celeborn.common.network.client.TransportClient

client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ import com.google.common.cache.{Cache, CacheBuilder}
3939

4040
import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers}
4141
import org.apache.celeborn.client.listener.WorkerStatusListener
42+
import org.apache.celeborn.client.util.ClientUtils
4243
import org.apache.celeborn.common.CelebornConf
4344
import org.apache.celeborn.common.client.MasterClient
4445
import org.apache.celeborn.common.identity.{IdentityProvider, UserIdentifier}

client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@ import java.util.concurrent.atomic.AtomicInteger
2525
import scala.collection.JavaConverters._
2626
import scala.collection.mutable
2727

28-
import org.apache.celeborn.client.{ClientUtils, LifecycleManager, ShuffleCommittedInfo, WorkerStatusTracker}
28+
import org.apache.celeborn.client.{LifecycleManager, ShuffleCommittedInfo, WorkerStatusTracker}
2929
import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
3030
import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers}
31+
import org.apache.celeborn.client.util.ClientUtils
3132
import org.apache.celeborn.common.CelebornConf
3233
import org.apache.celeborn.common.internal.Logging
3334
import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo}

client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@ import scala.collection.mutable
2828
import com.google.common.cache.{Cache, CacheBuilder}
2929
import com.google.common.collect.Sets
3030

31-
import org.apache.celeborn.client.{ClientUtils, LifecycleManager, ShuffleCommittedInfo, WorkerStatusTracker}
31+
import org.apache.celeborn.client.{LifecycleManager, ShuffleCommittedInfo, WorkerStatusTracker}
3232
import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
3333
import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers}
34+
import org.apache.celeborn.client.util.ClientUtils
3435
import org.apache.celeborn.common.CelebornConf
3536
import org.apache.celeborn.common.internal.Logging
3637
import org.apache.celeborn.common.meta.ShufflePartitionLocationInfo

client/src/main/scala/org/apache/celeborn/client/ClientUtils.scala client/src/main/scala/org/apache/celeborn/client/util/ClientUtils.scala

+2-3
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@
1515
* limitations under the License.
1616
*/
1717

18-
package org.apache.celeborn.client
19-
20-
import java.util.Collections
18+
package org.apache.celeborn.client.util
2119

20+
import org.apache.celeborn.client.LifecycleManager
2221
import org.apache.celeborn.common.CelebornConf
2322
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
2423

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.celeborn.client.util
19+
20+
import java.util.concurrent.ConcurrentHashMap
21+
22+
/**
 * This class is copied from Apache Spark.
 * A special locking mechanism to provide locking with a given key. By providing the same key
 * (identity is tested using the `equals` method), we ensure there is only one `func` running at
 * the same time.
 *
 * @tparam K the type of key to identify a lock. This type must implement `equals` and `hashCode`
 *           correctly as it will be the key type of an internal Map.
 */
class KeyLock[K] {

  // Presence of a key in this map means the lock for that key is currently held.
  // The mapped Object is the monitor that waiters block on.
  private val lockMap = new ConcurrentHashMap[K, AnyRef]()

  /**
   * Blocks until the lock for `key` is acquired, i.e. until this thread succeeds in
   * inserting its own marker object into `lockMap` for that key.
   */
  private def acquireLock(key: K): Unit = {
    while (true) {
      val lock = lockMap.putIfAbsent(key, new Object)
      // putIfAbsent returned null: no one held the lock, and our marker is now in the
      // map, so we own the lock.
      if (lock == null) return
      lock.synchronized {
        // Wait only while the current holder's marker is still the one in the map
        // (`eq` = reference identity). Once the holder removes it in releaseLock and
        // calls notifyAll, we loop back and race to insert our own marker.
        while (lockMap.get(key) eq lock) {
          lock.wait()
        }
      }
    }
  }

  /**
   * Releases the lock for `key` by removing its marker from the map and waking all
   * threads waiting on that marker. Must only be called by the thread that acquired
   * the lock; the remove-then-notify order is what lets waiters' re-check in
   * acquireLock observe the release.
   */
  private def releaseLock(key: K): Unit = {
    val lock = lockMap.remove(key)
    lock.synchronized {
      lock.notifyAll()
    }
  }

  /**
   * Run `func` under a lock identified by the given key. Multiple calls with the same key
   * (identity is tested using the `equals` method) will be locked properly to ensure there is only
   * one `func` running at the same time.
   */
  def withLock[T](key: K)(func: => T): T = {
    if (key == null) {
      throw new NullPointerException("key must not be null")
    }
    acquireLock(key)
    try {
      func
    } finally {
      // finally guarantees the lock is released even if `func` throws.
      releaseLock(key)
    }
  }
}

0 commit comments

Comments
 (0)