Skip to content

Commit cf133e6

Browse files
viirya authored and cloud-fan committed
[SPARK-26604][CORE] Clean up channel registration for StreamManager
## What changes were proposed in this pull request? Currently, when `TransportRequestHandler.processStreamRequest` processes a stream request, the stream id is not registered with the current channel in the stream manager. It should be registered, so that if the channel gets terminated we can also remove the streams associated with its stream requests. This also cleans up channel registration in `StreamManager`. Since `StreamManager` itself doesn't register channels and only `OneForOneStreamManager` does, this removes `registerChannel` from `StreamManager`. When `OneForOneStreamManager` registers a stream, it now also registers the channel for that stream. ## How was this patch tested? Existing tests. Closes apache#23521 from viirya/SPARK-26604. Authored-by: Liang-Chi Hsieh <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 2ebb79b commit cf133e6

File tree

9 files changed

+30
-31
lines changed

9 files changed

+30
-31
lines changed

common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java

-1
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@ protected void channelRead0(
9090
ManagedBuffer buf;
9191
try {
9292
streamManager.checkAuthorization(client, msg.streamChunkId.streamId);
93-
streamManager.registerChannel(channel, msg.streamChunkId.streamId);
9493
buf = streamManager.getChunk(msg.streamChunkId.streamId, msg.streamChunkId.chunkIndex);
9594
} catch (Exception e) {
9695
logger.error(String.format("Error opening block %s for request from %s",

common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java

+14-11
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.concurrent.ConcurrentHashMap;
2424
import java.util.concurrent.atomic.AtomicLong;
2525

26+
import com.google.common.annotations.VisibleForTesting;
2627
import com.google.common.base.Preconditions;
2728
import io.netty.channel.Channel;
2829
import org.apache.commons.lang3.tuple.ImmutablePair;
@@ -49,7 +50,7 @@ private static class StreamState {
4950
final Iterator<ManagedBuffer> buffers;
5051

5152
// The channel associated to the stream
52-
Channel associatedChannel = null;
53+
final Channel associatedChannel;
5354

5455
// Used to keep track of the index of the buffer that the user has retrieved, just to ensure
5556
// that the caller only requests each chunk one at a time, in order.
@@ -58,9 +59,10 @@ private static class StreamState {
5859
// Used to keep track of the number of chunks being transferred and not finished yet.
5960
volatile long chunksBeingTransferred = 0L;
6061

61-
StreamState(String appId, Iterator<ManagedBuffer> buffers) {
62+
StreamState(String appId, Iterator<ManagedBuffer> buffers, Channel channel) {
6263
this.appId = appId;
6364
this.buffers = Preconditions.checkNotNull(buffers);
65+
this.associatedChannel = channel;
6466
}
6567
}
6668

@@ -71,13 +73,6 @@ public OneForOneStreamManager() {
7173
streams = new ConcurrentHashMap<>();
7274
}
7375

74-
@Override
75-
public void registerChannel(Channel channel, long streamId) {
76-
if (streams.containsKey(streamId)) {
77-
streams.get(streamId).associatedChannel = channel;
78-
}
79-
}
80-
8176
@Override
8277
public ManagedBuffer getChunk(long streamId, int chunkIndex) {
8378
StreamState state = streams.get(streamId);
@@ -195,11 +190,19 @@ public long chunksBeingTransferred() {
195190
*
196191
* If an app ID is provided, only callers who've authenticated with the given app ID will be
197192
* allowed to fetch from this stream.
193+
*
194+
* This method also associates the stream with a single client connection, which is guaranteed
195+
* to be the only reader of the stream. Once the connection is closed, the stream will never
196+
* be used again, enabling cleanup by `connectionTerminated`.
198197
*/
199-
public long registerStream(String appId, Iterator<ManagedBuffer> buffers) {
198+
public long registerStream(String appId, Iterator<ManagedBuffer> buffers, Channel channel) {
200199
long myStreamId = nextStreamId.getAndIncrement();
201-
streams.put(myStreamId, new StreamState(appId, buffers));
200+
streams.put(myStreamId, new StreamState(appId, buffers, channel));
202201
return myStreamId;
203202
}
204203

204+
@VisibleForTesting
205+
public int numStreamStates() {
206+
return streams.size();
207+
}
205208
}

common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java

-10
Original file line numberDiff line numberDiff line change
@@ -60,16 +60,6 @@ public ManagedBuffer openStream(String streamId) {
6060
throw new UnsupportedOperationException();
6161
}
6262

63-
/**
64-
* Associates a stream with a single client connection, which is guaranteed to be the only reader
65-
* of the stream. The getChunk() method will be called serially on this connection and once the
66-
* connection is closed, the stream will never be used again, enabling cleanup.
67-
*
68-
* This must be called before the first getChunk() on the stream, but it may be invoked multiple
69-
* times with the same channel and stream id.
70-
*/
71-
public void registerChannel(Channel channel, long streamId) { }
72-
7363
/**
7464
* Indicates that the given channel has been terminated. After this occurs, we are guaranteed not
7565
* to read from the associated streams again, so any state can be cleaned up.

common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@ public void handleChunkFetchRequest() throws Exception {
6464
managedBuffers.add(new TestManagedBuffer(20));
6565
managedBuffers.add(new TestManagedBuffer(30));
6666
managedBuffers.add(new TestManagedBuffer(40));
67-
long streamId = streamManager.registerStream("test-app", managedBuffers.iterator());
68-
streamManager.registerChannel(channel, streamId);
67+
long streamId = streamManager.registerStream("test-app", managedBuffers.iterator(), channel);
6968
TransportClient reverseClient = mock(TransportClient.class);
7069
ChunkFetchRequestHandler requestHandler = new ChunkFetchRequestHandler(reverseClient,
7170
rpcHandler.getStreamManager(), 2L);

common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java

+7-2
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,10 @@ public void handleStreamRequest() throws Exception {
5858
managedBuffers.add(new TestManagedBuffer(20));
5959
managedBuffers.add(new TestManagedBuffer(30));
6060
managedBuffers.add(new TestManagedBuffer(40));
61-
long streamId = streamManager.registerStream("test-app", managedBuffers.iterator());
62-
streamManager.registerChannel(channel, streamId);
61+
long streamId = streamManager.registerStream("test-app", managedBuffers.iterator(), channel);
62+
63+
assert streamManager.numStreamStates() == 1;
64+
6365
TransportClient reverseClient = mock(TransportClient.class);
6466
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient,
6567
rpcHandler, 2L);
@@ -94,5 +96,8 @@ public void handleStreamRequest() throws Exception {
9496
requestHandler.handle(request3);
9597
verify(channel, times(1)).close();
9698
assert responseAndPromisePairs.size() == 3;
99+
100+
streamManager.connectionTerminated(channel);
101+
assert streamManager.numStreamStates() == 0;
97102
}
98103
}

common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,15 @@ public void managedBuffersAreFeedWhenConnectionIsClosed() throws Exception {
3737
TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20));
3838
buffers.add(buffer1);
3939
buffers.add(buffer2);
40-
long streamId = manager.registerStream("appId", buffers.iterator());
4140

4241
Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS);
43-
manager.registerChannel(dummyChannel, streamId);
42+
manager.registerStream("appId", buffers.iterator(), dummyChannel);
43+
assert manager.numStreamStates() == 1;
4444

4545
manager.connectionTerminated(dummyChannel);
4646

4747
Mockito.verify(buffer1, Mockito.times(1)).release();
4848
Mockito.verify(buffer2, Mockito.times(1)).release();
49+
assert manager.numStreamStates() == 0;
4950
}
5051
}

common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ protected void handleMessage(
9292
OpenBlocks msg = (OpenBlocks) msgObj;
9393
checkAuth(client, msg.appId);
9494
long streamId = streamManager.registerStream(client.getClientId(),
95-
new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds));
95+
new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds), client.getChannel());
9696
if (logger.isTraceEnabled()) {
9797
logger.trace("Registered streamId {} with {} buffers for client {} from host {}",
9898
streamId,

common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ public void testOpenShuffleBlocks() {
103103
@SuppressWarnings("unchecked")
104104
ArgumentCaptor<Iterator<ManagedBuffer>> stream = (ArgumentCaptor<Iterator<ManagedBuffer>>)
105105
(ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class);
106-
verify(streamManager, times(1)).registerStream(anyString(), stream.capture());
106+
verify(streamManager, times(1)).registerStream(anyString(), stream.capture(),
107+
any());
107108
Iterator<ManagedBuffer> buffers = stream.getValue();
108109
assertEquals(block0Marker, buffers.next());
109110
assertEquals(block1Marker, buffers.next());

core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ class NettyBlockRpcServer(
5959
val blocksNum = openBlocks.blockIds.length
6060
val blocks = for (i <- (0 until blocksNum).view)
6161
yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i)))
62-
val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
62+
val streamId = streamManager.registerStream(appId, blocks.iterator.asJava,
63+
client.getChannel)
6364
logTrace(s"Registered streamId $streamId with $blocksNum buffers")
6465
responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer)
6566

0 commit comments

Comments
 (0)