Skip to content

Commit f0c6178

Browse files
committed
[FLINK-38267][checkpoint] Only call channel state rescaling logic for exchange with channel state to avoid UnsupportedOperationException
1 parent d6bb5c9 commit f0c6178

File tree

8 files changed

+923
-47
lines changed

8 files changed

+923
-47
lines changed

flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/InflightDataRescalingDescriptor.java

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.io.ObjectStreamException;
2121
import java.io.Serializable;
2222
import java.util.Arrays;
23+
import java.util.Collections;
2324
import java.util.Objects;
2425
import java.util.Set;
2526

@@ -45,11 +46,11 @@ public InflightDataRescalingDescriptor(
4546
}
4647

4748
public int[] getOldSubtaskIndexes(int gateOrPartitionIndex) {
48-
return gateOrPartitionDescriptors[gateOrPartitionIndex].oldSubtaskIndexes;
49+
return gateOrPartitionDescriptors[gateOrPartitionIndex].getOldSubtaskInstances();
4950
}
5051

5152
public RescaleMappings getChannelMapping(int gateOrPartitionIndex) {
52-
return gateOrPartitionDescriptors[gateOrPartitionIndex].rescaledChannelsMappings;
53+
return gateOrPartitionDescriptors[gateOrPartitionIndex].getRescaleMappings();
5354
}
5455

5556
public boolean isAmbiguous(int gateOrPartitionIndex, int oldSubtaskIndex) {
@@ -112,6 +113,28 @@ public String toString() {
112113
*/
113114
public static class InflightDataGateOrPartitionRescalingDescriptor implements Serializable {
114115

116+
public static final InflightDataGateOrPartitionRescalingDescriptor NO_STATE =
117+
new InflightDataGateOrPartitionRescalingDescriptor(
118+
new int[0],
119+
RescaleMappings.identity(0, 0),
120+
Collections.emptySet(),
121+
MappingType.IDENTITY) {
122+
123+
private static final long serialVersionUID = 1L;
124+
125+
@Override
126+
public int[] getOldSubtaskInstances() {
127+
throw new UnsupportedOperationException(
128+
"Cannot get old subtasks from a descriptor that represents no state.");
129+
}
130+
131+
@Override
132+
public RescaleMappings getRescaleMappings() {
133+
throw new UnsupportedOperationException(
134+
"Cannot get rescale mappings from a descriptor that represents no state.");
135+
}
136+
};
137+
115138
private static final long serialVersionUID = 1L;
116139

117140
/** Set when several operator instances are merged into one. */
@@ -145,6 +168,14 @@ public InflightDataGateOrPartitionRescalingDescriptor(
145168
this.mappingType = mappingType;
146169
}
147170

171+
public int[] getOldSubtaskInstances() {
172+
return oldSubtaskIndexes;
173+
}
174+
175+
public RescaleMappings getRescaleMappings() {
176+
return rescaledChannelsMappings;
177+
}
178+
148179
public boolean isIdentity() {
149180
return mappingType == MappingType.IDENTITY;
150181
}

flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,9 @@ public void reDistributeResultSubpartitionStates(TaskStateAssignment assignment)
385385
// Parallelism of this vertex changed, distribute ResultSubpartitionStateHandle
386386
// according to output mapping.
387387
for (int partitionIndex = 0; partitionIndex < outputs.size(); partitionIndex++) {
388+
if (!assignment.hasInFlightDataForResultPartition(partitionIndex)) {
389+
continue;
390+
}
388391
final List<List<ResultSubpartitionStateHandle>> partitionState =
389392
outputs.size() == 1
390393
? outputOperatorState
@@ -465,6 +468,9 @@ public void reDistributeInputChannelStates(TaskStateAssignment stateAssignment)
465468
// subtask 0 recovers data from old subtask 0 + 1 and subtask 1 recovers data from old
466469
// subtask 1 + 2
467470
for (int gateIndex = 0; gateIndex < inputs.size(); gateIndex++) {
471+
if (!stateAssignment.hasInFlightDataForInputGate(gateIndex)) {
472+
continue;
473+
}
468474
final RescaleMappings mapping =
469475
stateAssignment.getInputMapping(gateIndex).getRescaleMappings();
470476

flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java

Lines changed: 68 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
import org.apache.flink.runtime.OperatorIDPair;
2121
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor;
2222
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor.MappingType;
23-
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
24-
import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
2523
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
2624
import org.apache.flink.runtime.executiongraph.IntermediateResult;
2725
import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper;
@@ -30,8 +28,6 @@
3028
import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
3129
import org.apache.flink.runtime.state.InputChannelStateHandle;
3230
import org.apache.flink.runtime.state.KeyedStateHandle;
33-
import org.apache.flink.runtime.state.MergedInputChannelStateHandle;
34-
import org.apache.flink.runtime.state.MergedResultSubpartitionStateHandle;
3531
import org.apache.flink.runtime.state.OperatorStateHandle;
3632
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
3733
import org.apache.flink.runtime.state.StateObject;
@@ -147,46 +143,15 @@ private static Set<Integer> extractInputStateGates(OperatorState operatorState)
147143
return operatorState.getStates().stream()
148144
.map(OperatorSubtaskState::getInputChannelState)
149145
.flatMap(Collection::stream)
150-
.flatMapToInt(
151-
handle -> {
152-
if (handle instanceof InputChannelStateHandle) {
153-
return IntStream.of(
154-
((InputChannelStateHandle) handle).getInfo().getGateIdx());
155-
} else if (handle instanceof MergedInputChannelStateHandle) {
156-
return ((MergedInputChannelStateHandle) handle)
157-
.getInfos().stream().mapToInt(InputChannelInfo::getGateIdx);
158-
} else {
159-
throw new IllegalStateException(
160-
"Invalid input channel state : " + handle.getClass());
161-
}
162-
})
163-
.distinct()
164-
.boxed()
146+
.map(handle -> handle.getInfo().getGateIdx())
165147
.collect(Collectors.toSet());
166148
}
167149

168150
private static Set<Integer> extractOutputStatePartitions(OperatorState operatorState) {
169151
return operatorState.getStates().stream()
170152
.map(OperatorSubtaskState::getResultSubpartitionState)
171153
.flatMap(Collection::stream)
172-
.flatMapToInt(
173-
handle -> {
174-
if (handle instanceof ResultSubpartitionStateHandle) {
175-
return IntStream.of(
176-
((ResultSubpartitionStateHandle) handle)
177-
.getInfo()
178-
.getPartitionIdx());
179-
} else if (handle instanceof MergedResultSubpartitionStateHandle) {
180-
return ((MergedResultSubpartitionStateHandle) handle)
181-
.getInfos().stream()
182-
.mapToInt(ResultSubpartitionInfo::getPartitionIdx);
183-
} else {
184-
throw new IllegalStateException(
185-
"Invalid output channel state : " + handle.getClass());
186-
}
187-
})
188-
.distinct()
189-
.boxed()
154+
.map(handle -> handle.getInfo().getPartitionIdx())
190155
.collect(Collectors.toSet());
191156
}
192157

@@ -252,7 +217,8 @@ public OperatorSubtaskState getSubtaskState(OperatorInstanceID instanceID) {
252217
return assignment.getOutputMapping(assignmentIndex, recompute);
253218
},
254219
inputSubtaskMappings,
255-
this::getInputMapping))
220+
this::getInputMapping,
221+
true))
256222
.setOutputRescalingDescriptor(
257223
createRescalingDescriptor(
258224
instanceID,
@@ -265,7 +231,8 @@ public OperatorSubtaskState getSubtaskState(OperatorInstanceID instanceID) {
265231
return assignment.getInputMapping(assignmentIndex, recompute);
266232
},
267233
outputSubtaskMappings,
268-
this::getOutputMapping))
234+
this::getOutputMapping,
235+
false))
269236
.build();
270237
}
271238

@@ -314,7 +281,8 @@ private InflightDataRescalingDescriptor createRescalingDescriptor(
314281
TaskStateAssignment[] connectedAssignments,
315282
BiFunction<TaskStateAssignment, Boolean, SubtasksRescaleMapping> mappingRetriever,
316283
Map<Integer, SubtasksRescaleMapping> subtaskGateOrPartitionMappings,
317-
Function<Integer, SubtasksRescaleMapping> subtaskMappingCalculator) {
284+
Function<Integer, SubtasksRescaleMapping> subtaskMappingCalculator,
285+
boolean isInput) {
318286
if (!expectedOperatorID.equals(instanceID.getOperatorId())) {
319287
return InflightDataRescalingDescriptor.NO_RESCALE;
320288
}
@@ -337,7 +305,8 @@ private InflightDataRescalingDescriptor createRescalingDescriptor(
337305
assignment -> mappingRetriever.apply(assignment, true),
338306
subtaskGateOrPartitionMappings,
339307
subtaskMappingCalculator,
340-
rescaledChannelsMappings);
308+
rescaledChannelsMappings,
309+
isInput);
341310

342311
if (Arrays.stream(gateOrPartitionDescriptors)
343312
.allMatch(InflightDataGateOrPartitionRescalingDescriptor::isIdentity)) {
@@ -356,10 +325,14 @@ private InflightDataRescalingDescriptor createRescalingDescriptor(
356325
Function<TaskStateAssignment, SubtasksRescaleMapping> mappingCalculator,
357326
Map<Integer, SubtasksRescaleMapping> subtaskGateOrPartitionMappings,
358327
Function<Integer, SubtasksRescaleMapping> subtaskMappingCalculator,
359-
SubtasksRescaleMapping[] rescaledChannelsMappings) {
328+
SubtasksRescaleMapping[] rescaledChannelsMappings,
329+
boolean isInput) {
360330
return IntStream.range(0, rescaledChannelsMappings.length)
361331
.mapToObj(
362332
partition -> {
333+
if (!hasInFlightData(isInput, partition)) {
334+
return InflightDataGateOrPartitionRescalingDescriptor.NO_STATE;
335+
}
363336
TaskStateAssignment connectedAssignment =
364337
connectedAssignments[partition];
365338
SubtasksRescaleMapping rescaleMapping =
@@ -381,6 +354,14 @@ private InflightDataRescalingDescriptor createRescalingDescriptor(
381354
.toArray(InflightDataGateOrPartitionRescalingDescriptor[]::new);
382355
}
383356

357+
private boolean hasInFlightData(boolean isInput, int gateOrPartitionIndex) {
358+
if (isInput) {
359+
return hasInFlightDataForInputGate(gateOrPartitionIndex);
360+
} else {
361+
return hasInFlightDataForResultPartition(gateOrPartitionIndex);
362+
}
363+
}
364+
384365
private InflightDataGateOrPartitionRescalingDescriptor
385366
getInflightDataGateOrPartitionRescalingDescriptor(
386367
OperatorInstanceID instanceID,
@@ -479,6 +460,51 @@ public SubtasksRescaleMapping getInputMapping(int gateIndex) {
479460
checkSubtaskMapping(oldMapping, mapping, mapper.isAmbiguous()));
480461
}
481462

463+
public boolean hasInFlightDataForInputGate(int gateIndex) {
464+
// Check own input state for this gate
465+
if (inputStateGates.contains(gateIndex)) {
466+
return true;
467+
}
468+
469+
// Check upstream output state for this gate
470+
TaskStateAssignment upstreamAssignment = getUpstreamAssignments()[gateIndex];
471+
if (upstreamAssignment != null && upstreamAssignment.hasOutputState()) {
472+
IntermediateResult inputResult = executionJobVertex.getInputs().get(gateIndex);
473+
IntermediateDataSetID resultId = inputResult.getId();
474+
IntermediateResult[] producedDataSets = inputResult.getProducer().getProducedDataSets();
475+
for (int i = 0; i < producedDataSets.length; i++) {
476+
if (producedDataSets[i].getId().equals(resultId)) {
477+
return upstreamAssignment.outputStatePartitions.contains(i);
478+
}
479+
}
480+
}
481+
482+
return false;
483+
}
484+
485+
public boolean hasInFlightDataForResultPartition(int partitionIndex) {
486+
// Check own output state for this partition
487+
if (outputStatePartitions.contains(partitionIndex)) {
488+
return true;
489+
}
490+
491+
// Check downstream input state for this partition
492+
TaskStateAssignment downstreamAssignment = getDownstreamAssignments()[partitionIndex];
493+
494+
if (downstreamAssignment != null && downstreamAssignment.hasInputState()) {
495+
IntermediateResult producedResult =
496+
executionJobVertex.getProducedDataSets()[partitionIndex];
497+
IntermediateDataSetID resultId = producedResult.getId();
498+
List<IntermediateResult> inputs = downstreamAssignment.executionJobVertex.getInputs();
499+
for (int i = 0; i < inputs.size(); i++) {
500+
if (inputs.get(i).getId().equals(resultId)) {
501+
return downstreamAssignment.inputStateGates.contains(i);
502+
}
503+
}
504+
}
505+
return false;
506+
}
507+
482508
@Override
483509
public String toString() {
484510
return "TaskStateAssignment for " + executionJobVertex.getName();
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.runtime.checkpoint;
20+
21+
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor;
22+
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor.MappingType;
23+
24+
import org.junit.jupiter.api.Test;
25+
26+
import java.util.Arrays;
27+
import java.util.Collections;
28+
29+
import static org.assertj.core.api.Assertions.assertThat;
30+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
31+
32+
/** Tests for {@link InflightDataRescalingDescriptor}. */
33+
class InflightDataRescalingDescriptorTest {
34+
35+
@Test
36+
void testNoStateDescriptorThrowsOnGetOldSubtaskInstances() {
37+
InflightDataGateOrPartitionRescalingDescriptor noStateDescriptor =
38+
InflightDataGateOrPartitionRescalingDescriptor.NO_STATE;
39+
40+
assertThatThrownBy(noStateDescriptor::getOldSubtaskInstances)
41+
.isInstanceOf(UnsupportedOperationException.class)
42+
.hasMessageContaining(
43+
"Cannot get old subtasks from a descriptor that represents no state");
44+
}
45+
46+
@Test
47+
void testNoStateDescriptorThrowsOnGetRescaleMappings() {
48+
InflightDataGateOrPartitionRescalingDescriptor noStateDescriptor =
49+
InflightDataGateOrPartitionRescalingDescriptor.NO_STATE;
50+
51+
assertThatThrownBy(noStateDescriptor::getRescaleMappings)
52+
.isInstanceOf(UnsupportedOperationException.class)
53+
.hasMessageContaining(
54+
"Cannot get rescale mappings from a descriptor that represents no state");
55+
}
56+
57+
@Test
58+
void testNoStateDescriptorIsIdentity() {
59+
InflightDataGateOrPartitionRescalingDescriptor noStateDescriptor =
60+
InflightDataGateOrPartitionRescalingDescriptor.NO_STATE;
61+
62+
assertThat(noStateDescriptor.isIdentity()).isTrue();
63+
}
64+
65+
@Test
66+
void testRegularDescriptorDoesNotThrow() {
67+
int[] oldSubtasks = new int[] {0, 1, 2};
68+
RescaleMappings mappings =
69+
RescaleMappings.of(Arrays.stream(new int[][] {{0}, {1}, {2}}), 3);
70+
71+
InflightDataGateOrPartitionRescalingDescriptor descriptor =
72+
new InflightDataGateOrPartitionRescalingDescriptor(
73+
oldSubtasks, mappings, Collections.emptySet(), MappingType.RESCALING);
74+
75+
// Should not throw
76+
assertThat(descriptor.getOldSubtaskInstances()).isEqualTo(oldSubtasks);
77+
assertThat(descriptor.getRescaleMappings()).isEqualTo(mappings);
78+
assertThat(descriptor.isIdentity()).isFalse();
79+
}
80+
81+
@Test
82+
void testIdentityDescriptor() {
83+
int[] oldSubtasks = new int[] {0};
84+
RescaleMappings mappings = RescaleMappings.identity(1, 1);
85+
86+
InflightDataGateOrPartitionRescalingDescriptor descriptor =
87+
new InflightDataGateOrPartitionRescalingDescriptor(
88+
oldSubtasks, mappings, Collections.emptySet(), MappingType.IDENTITY);
89+
90+
assertThat(descriptor.isIdentity()).isTrue();
91+
assertThat(descriptor.getOldSubtaskInstances()).isEqualTo(oldSubtasks);
92+
assertThat(descriptor.getRescaleMappings()).isEqualTo(mappings);
93+
}
94+
95+
@Test
96+
void testInflightDataRescalingDescriptorWithNoStateDescriptor() {
97+
// Create a descriptor array with NO_STATE descriptor
98+
InflightDataGateOrPartitionRescalingDescriptor[] descriptors =
99+
new InflightDataGateOrPartitionRescalingDescriptor[] {
100+
InflightDataGateOrPartitionRescalingDescriptor.NO_STATE,
101+
new InflightDataGateOrPartitionRescalingDescriptor(
102+
new int[] {0, 1},
103+
RescaleMappings.of(Arrays.stream(new int[][] {{0}, {1}}), 2),
104+
Collections.emptySet(),
105+
MappingType.RESCALING)
106+
};
107+
108+
InflightDataRescalingDescriptor rescalingDescriptor =
109+
new InflightDataRescalingDescriptor(descriptors);
110+
111+
// First gate/partition has NO_STATE
112+
assertThatThrownBy(() -> rescalingDescriptor.getOldSubtaskIndexes(0))
113+
.isInstanceOf(UnsupportedOperationException.class);
114+
assertThatThrownBy(() -> rescalingDescriptor.getChannelMapping(0))
115+
.isInstanceOf(UnsupportedOperationException.class);
116+
117+
// Second gate/partition has normal state
118+
assertThat(rescalingDescriptor.getOldSubtaskIndexes(1)).isEqualTo(new int[] {0, 1});
119+
assertThat(rescalingDescriptor.getChannelMapping(1)).isNotNull();
120+
}
121+
}

0 commit comments

Comments
 (0)