Skip to content

Commit 37a74f2

Browse files
committed
[FLINK-38267][checkpoint] Only call channel state rescaling logic for exchange with channel state to avoid UnsupportedOperationException
1 parent 5e91570 commit 37a74f2

File tree

8 files changed

+923
-47
lines changed

8 files changed

+923
-47
lines changed

flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/InflightDataRescalingDescriptor.java

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.io.ObjectStreamException;
2121
import java.io.Serializable;
2222
import java.util.Arrays;
23+
import java.util.Collections;
2324
import java.util.Objects;
2425
import java.util.Set;
2526

@@ -45,11 +46,11 @@ public InflightDataRescalingDescriptor(
4546
}
4647

4748
public int[] getOldSubtaskIndexes(int gateOrPartitionIndex) {
48-
return gateOrPartitionDescriptors[gateOrPartitionIndex].oldSubtaskIndexes;
49+
return gateOrPartitionDescriptors[gateOrPartitionIndex].getOldSubtaskInstances();
4950
}
5051

5152
public RescaleMappings getChannelMapping(int gateOrPartitionIndex) {
52-
return gateOrPartitionDescriptors[gateOrPartitionIndex].rescaledChannelsMappings;
53+
return gateOrPartitionDescriptors[gateOrPartitionIndex].getRescaleMappings();
5354
}
5455

5556
public boolean isAmbiguous(int gateOrPartitionIndex, int oldSubtaskIndex) {
@@ -112,6 +113,28 @@ public String toString() {
112113
*/
113114
public static class InflightDataGateOrPartitionRescalingDescriptor implements Serializable {
114115

116+
public static final InflightDataGateOrPartitionRescalingDescriptor NO_STATE =
117+
new InflightDataGateOrPartitionRescalingDescriptor(
118+
new int[0],
119+
RescaleMappings.identity(0, 0),
120+
Collections.emptySet(),
121+
MappingType.IDENTITY) {
122+
123+
private static final long serialVersionUID = 1L;
124+
125+
@Override
126+
public int[] getOldSubtaskInstances() {
127+
throw new UnsupportedOperationException(
128+
"Cannot get old subtasks from a descriptor that represents no state.");
129+
}
130+
131+
@Override
132+
public RescaleMappings getRescaleMappings() {
133+
throw new UnsupportedOperationException(
134+
"Cannot get rescale mappings from a descriptor that represents no state.");
135+
}
136+
};
137+
115138
private static final long serialVersionUID = 1L;
116139

117140
/** Set when several operator instances are merged into one. */
@@ -145,6 +168,14 @@ public InflightDataGateOrPartitionRescalingDescriptor(
145168
this.mappingType = mappingType;
146169
}
147170

171+
public int[] getOldSubtaskInstances() {
172+
return oldSubtaskIndexes;
173+
}
174+
175+
public RescaleMappings getRescaleMappings() {
176+
return rescaledChannelsMappings;
177+
}
178+
148179
public boolean isIdentity() {
149180
return mappingType == MappingType.IDENTITY;
150181
}

flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,9 @@ public void reDistributeResultSubpartitionStates(TaskStateAssignment assignment)
382382
// Parallelism of this vertex changed, distribute ResultSubpartitionStateHandle
383383
// according to output mapping.
384384
for (int partitionIndex = 0; partitionIndex < outputs.size(); partitionIndex++) {
385+
if (!assignment.hasInFlightDataForResultPartition(partitionIndex)) {
386+
continue;
387+
}
385388
final List<List<ResultSubpartitionStateHandle>> partitionState =
386389
outputs.size() == 1
387390
? outputOperatorState
@@ -462,6 +465,9 @@ public void reDistributeInputChannelStates(TaskStateAssignment stateAssignment)
462465
// subtask 0 recovers data from old subtask 0 + 1 and subtask 1 recovers data from old
463466
// subtask 1 + 2
464467
for (int gateIndex = 0; gateIndex < inputs.size(); gateIndex++) {
468+
if (!stateAssignment.hasInFlightDataForInputGate(gateIndex)) {
469+
continue;
470+
}
465471
final RescaleMappings mapping =
466472
stateAssignment.getInputMapping(gateIndex).getRescaleMappings();
467473

flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java

Lines changed: 68 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
import org.apache.flink.runtime.OperatorIDPair;
2121
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor;
2222
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor.MappingType;
23-
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
24-
import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
2523
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
2624
import org.apache.flink.runtime.executiongraph.IntermediateResult;
2725
import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper;
@@ -30,8 +28,6 @@
3028
import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
3129
import org.apache.flink.runtime.state.InputChannelStateHandle;
3230
import org.apache.flink.runtime.state.KeyedStateHandle;
33-
import org.apache.flink.runtime.state.MergedInputChannelStateHandle;
34-
import org.apache.flink.runtime.state.MergedResultSubpartitionStateHandle;
3531
import org.apache.flink.runtime.state.OperatorStateHandle;
3632
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
3733
import org.apache.flink.runtime.state.StateObject;
@@ -145,46 +141,15 @@ private static Set<Integer> extractInputStateGates(OperatorState operatorState)
145141
return operatorState.getStates().stream()
146142
.map(OperatorSubtaskState::getInputChannelState)
147143
.flatMap(Collection::stream)
148-
.flatMapToInt(
149-
handle -> {
150-
if (handle instanceof InputChannelStateHandle) {
151-
return IntStream.of(
152-
((InputChannelStateHandle) handle).getInfo().getGateIdx());
153-
} else if (handle instanceof MergedInputChannelStateHandle) {
154-
return ((MergedInputChannelStateHandle) handle)
155-
.getInfos().stream().mapToInt(InputChannelInfo::getGateIdx);
156-
} else {
157-
throw new IllegalStateException(
158-
"Invalid input channel state : " + handle.getClass());
159-
}
160-
})
161-
.distinct()
162-
.boxed()
144+
.map(handle -> handle.getInfo().getGateIdx())
163145
.collect(Collectors.toSet());
164146
}
165147

166148
private static Set<Integer> extractOutputStatePartitions(OperatorState operatorState) {
167149
return operatorState.getStates().stream()
168150
.map(OperatorSubtaskState::getResultSubpartitionState)
169151
.flatMap(Collection::stream)
170-
.flatMapToInt(
171-
handle -> {
172-
if (handle instanceof ResultSubpartitionStateHandle) {
173-
return IntStream.of(
174-
((ResultSubpartitionStateHandle) handle)
175-
.getInfo()
176-
.getPartitionIdx());
177-
} else if (handle instanceof MergedResultSubpartitionStateHandle) {
178-
return ((MergedResultSubpartitionStateHandle) handle)
179-
.getInfos().stream()
180-
.mapToInt(ResultSubpartitionInfo::getPartitionIdx);
181-
} else {
182-
throw new IllegalStateException(
183-
"Invalid output channel state : " + handle.getClass());
184-
}
185-
})
186-
.distinct()
187-
.boxed()
152+
.map(handle -> handle.getInfo().getPartitionIdx())
188153
.collect(Collectors.toSet());
189154
}
190155

@@ -250,7 +215,8 @@ public OperatorSubtaskState getSubtaskState(OperatorInstanceID instanceID) {
250215
return assignment.getOutputMapping(assignmentIndex, recompute);
251216
},
252217
inputSubtaskMappings,
253-
this::getInputMapping))
218+
this::getInputMapping,
219+
true))
254220
.setOutputRescalingDescriptor(
255221
createRescalingDescriptor(
256222
instanceID,
@@ -263,7 +229,8 @@ public OperatorSubtaskState getSubtaskState(OperatorInstanceID instanceID) {
263229
return assignment.getInputMapping(assignmentIndex, recompute);
264230
},
265231
outputSubtaskMappings,
266-
this::getOutputMapping))
232+
this::getOutputMapping,
233+
false))
267234
.build();
268235
}
269236

@@ -312,7 +279,8 @@ private InflightDataRescalingDescriptor createRescalingDescriptor(
312279
TaskStateAssignment[] connectedAssignments,
313280
BiFunction<TaskStateAssignment, Boolean, SubtasksRescaleMapping> mappingRetriever,
314281
Map<Integer, SubtasksRescaleMapping> subtaskGateOrPartitionMappings,
315-
Function<Integer, SubtasksRescaleMapping> subtaskMappingCalculator) {
282+
Function<Integer, SubtasksRescaleMapping> subtaskMappingCalculator,
283+
boolean isInput) {
316284
if (!expectedOperatorID.equals(instanceID.getOperatorId())) {
317285
return InflightDataRescalingDescriptor.NO_RESCALE;
318286
}
@@ -335,7 +303,8 @@ private InflightDataRescalingDescriptor createRescalingDescriptor(
335303
assignment -> mappingRetriever.apply(assignment, true),
336304
subtaskGateOrPartitionMappings,
337305
subtaskMappingCalculator,
338-
rescaledChannelsMappings);
306+
rescaledChannelsMappings,
307+
isInput);
339308

340309
if (Arrays.stream(gateOrPartitionDescriptors)
341310
.allMatch(InflightDataGateOrPartitionRescalingDescriptor::isIdentity)) {
@@ -354,10 +323,14 @@ private InflightDataRescalingDescriptor createRescalingDescriptor(
354323
Function<TaskStateAssignment, SubtasksRescaleMapping> mappingCalculator,
355324
Map<Integer, SubtasksRescaleMapping> subtaskGateOrPartitionMappings,
356325
Function<Integer, SubtasksRescaleMapping> subtaskMappingCalculator,
357-
SubtasksRescaleMapping[] rescaledChannelsMappings) {
326+
SubtasksRescaleMapping[] rescaledChannelsMappings,
327+
boolean isInput) {
358328
return IntStream.range(0, rescaledChannelsMappings.length)
359329
.mapToObj(
360330
partition -> {
331+
if (!hasInFlightData(isInput, partition)) {
332+
return InflightDataGateOrPartitionRescalingDescriptor.NO_STATE;
333+
}
361334
TaskStateAssignment connectedAssignment =
362335
connectedAssignments[partition];
363336
SubtasksRescaleMapping rescaleMapping =
@@ -379,6 +352,14 @@ private InflightDataRescalingDescriptor createRescalingDescriptor(
379352
.toArray(InflightDataGateOrPartitionRescalingDescriptor[]::new);
380353
}
381354

355+
private boolean hasInFlightData(boolean isInput, int gateOrPartitionIndex) {
356+
if (isInput) {
357+
return hasInFlightDataForInputGate(gateOrPartitionIndex);
358+
} else {
359+
return hasInFlightDataForResultPartition(gateOrPartitionIndex);
360+
}
361+
}
362+
382363
private InflightDataGateOrPartitionRescalingDescriptor
383364
getInflightDataGateOrPartitionRescalingDescriptor(
384365
OperatorInstanceID instanceID,
@@ -477,6 +458,51 @@ public SubtasksRescaleMapping getInputMapping(int gateIndex) {
477458
checkSubtaskMapping(oldMapping, mapping, mapper.isAmbiguous()));
478459
}
479460

461+
public boolean hasInFlightDataForInputGate(int gateIndex) {
462+
// Check own input state for this gate
463+
if (inputStateGates.contains(gateIndex)) {
464+
return true;
465+
}
466+
467+
// Check upstream output state for this gate
468+
TaskStateAssignment upstreamAssignment = getUpstreamAssignments()[gateIndex];
469+
if (upstreamAssignment != null && upstreamAssignment.hasOutputState()) {
470+
IntermediateResult inputResult = executionJobVertex.getInputs().get(gateIndex);
471+
IntermediateDataSetID resultId = inputResult.getId();
472+
IntermediateResult[] producedDataSets = inputResult.getProducer().getProducedDataSets();
473+
for (int i = 0; i < producedDataSets.length; i++) {
474+
if (producedDataSets[i].getId().equals(resultId)) {
475+
return upstreamAssignment.outputStatePartitions.contains(i);
476+
}
477+
}
478+
}
479+
480+
return false;
481+
}
482+
483+
public boolean hasInFlightDataForResultPartition(int partitionIndex) {
484+
// Check own output state for this partition
485+
if (outputStatePartitions.contains(partitionIndex)) {
486+
return true;
487+
}
488+
489+
// Check downstream input state for this partition
490+
TaskStateAssignment downstreamAssignment = getDownstreamAssignments()[partitionIndex];
491+
492+
if (downstreamAssignment != null && downstreamAssignment.hasInputState()) {
493+
IntermediateResult producedResult =
494+
executionJobVertex.getProducedDataSets()[partitionIndex];
495+
IntermediateDataSetID resultId = producedResult.getId();
496+
List<IntermediateResult> inputs = downstreamAssignment.executionJobVertex.getInputs();
497+
for (int i = 0; i < inputs.size(); i++) {
498+
if (inputs.get(i).getId().equals(resultId)) {
499+
return downstreamAssignment.inputStateGates.contains(i);
500+
}
501+
}
502+
}
503+
return false;
504+
}
505+
480506
@Override
481507
public String toString() {
482508
return "TaskStateAssignment for " + executionJobVertex.getName();
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.runtime.checkpoint;
20+
21+
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor;
22+
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor.MappingType;
23+
24+
import org.junit.jupiter.api.Test;
25+
26+
import java.util.Arrays;
27+
import java.util.Collections;
28+
29+
import static org.assertj.core.api.Assertions.assertThat;
30+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
31+
32+
/** Tests for {@link InflightDataRescalingDescriptor}. */
33+
class InflightDataRescalingDescriptorTest {
34+
35+
@Test
36+
void testNoStateDescriptorThrowsOnGetOldSubtaskInstances() {
37+
InflightDataGateOrPartitionRescalingDescriptor noStateDescriptor =
38+
InflightDataGateOrPartitionRescalingDescriptor.NO_STATE;
39+
40+
assertThatThrownBy(noStateDescriptor::getOldSubtaskInstances)
41+
.isInstanceOf(UnsupportedOperationException.class)
42+
.hasMessageContaining(
43+
"Cannot get old subtasks from a descriptor that represents no state");
44+
}
45+
46+
@Test
47+
void testNoStateDescriptorThrowsOnGetRescaleMappings() {
48+
InflightDataGateOrPartitionRescalingDescriptor noStateDescriptor =
49+
InflightDataGateOrPartitionRescalingDescriptor.NO_STATE;
50+
51+
assertThatThrownBy(noStateDescriptor::getRescaleMappings)
52+
.isInstanceOf(UnsupportedOperationException.class)
53+
.hasMessageContaining(
54+
"Cannot get rescale mappings from a descriptor that represents no state");
55+
}
56+
57+
@Test
58+
void testNoStateDescriptorIsIdentity() {
59+
InflightDataGateOrPartitionRescalingDescriptor noStateDescriptor =
60+
InflightDataGateOrPartitionRescalingDescriptor.NO_STATE;
61+
62+
assertThat(noStateDescriptor.isIdentity()).isTrue();
63+
}
64+
65+
@Test
66+
void testRegularDescriptorDoesNotThrow() {
67+
int[] oldSubtasks = new int[] {0, 1, 2};
68+
RescaleMappings mappings =
69+
RescaleMappings.of(Arrays.stream(new int[][] {{0}, {1}, {2}}), 3);
70+
71+
InflightDataGateOrPartitionRescalingDescriptor descriptor =
72+
new InflightDataGateOrPartitionRescalingDescriptor(
73+
oldSubtasks, mappings, Collections.emptySet(), MappingType.RESCALING);
74+
75+
// Should not throw
76+
assertThat(descriptor.getOldSubtaskInstances()).isEqualTo(oldSubtasks);
77+
assertThat(descriptor.getRescaleMappings()).isEqualTo(mappings);
78+
assertThat(descriptor.isIdentity()).isFalse();
79+
}
80+
81+
@Test
82+
void testIdentityDescriptor() {
83+
int[] oldSubtasks = new int[] {0};
84+
RescaleMappings mappings = RescaleMappings.identity(1, 1);
85+
86+
InflightDataGateOrPartitionRescalingDescriptor descriptor =
87+
new InflightDataGateOrPartitionRescalingDescriptor(
88+
oldSubtasks, mappings, Collections.emptySet(), MappingType.IDENTITY);
89+
90+
assertThat(descriptor.isIdentity()).isTrue();
91+
assertThat(descriptor.getOldSubtaskInstances()).isEqualTo(oldSubtasks);
92+
assertThat(descriptor.getRescaleMappings()).isEqualTo(mappings);
93+
}
94+
95+
@Test
96+
void testInflightDataRescalingDescriptorWithNoStateDescriptor() {
97+
// Create a descriptor array with NO_STATE descriptor
98+
InflightDataGateOrPartitionRescalingDescriptor[] descriptors =
99+
new InflightDataGateOrPartitionRescalingDescriptor[] {
100+
InflightDataGateOrPartitionRescalingDescriptor.NO_STATE,
101+
new InflightDataGateOrPartitionRescalingDescriptor(
102+
new int[] {0, 1},
103+
RescaleMappings.of(Arrays.stream(new int[][] {{0}, {1}}), 2),
104+
Collections.emptySet(),
105+
MappingType.RESCALING)
106+
};
107+
108+
InflightDataRescalingDescriptor rescalingDescriptor =
109+
new InflightDataRescalingDescriptor(descriptors);
110+
111+
// First gate/partition has NO_STATE
112+
assertThatThrownBy(() -> rescalingDescriptor.getOldSubtaskIndexes(0))
113+
.isInstanceOf(UnsupportedOperationException.class);
114+
assertThatThrownBy(() -> rescalingDescriptor.getChannelMapping(0))
115+
.isInstanceOf(UnsupportedOperationException.class);
116+
117+
// Second gate/partition has normal state
118+
assertThat(rescalingDescriptor.getOldSubtaskIndexes(1)).isEqualTo(new int[] {0, 1});
119+
assertThat(rescalingDescriptor.getChannelMapping(1)).isNotNull();
120+
}
121+
}

0 commit comments

Comments
 (0)