20
20
import org .apache .flink .runtime .OperatorIDPair ;
21
21
import org .apache .flink .runtime .checkpoint .InflightDataRescalingDescriptor .InflightDataGateOrPartitionRescalingDescriptor ;
22
22
import org .apache .flink .runtime .checkpoint .InflightDataRescalingDescriptor .InflightDataGateOrPartitionRescalingDescriptor .MappingType ;
23
- import org .apache .flink .runtime .checkpoint .channel .InputChannelInfo ;
24
- import org .apache .flink .runtime .checkpoint .channel .ResultSubpartitionInfo ;
25
23
import org .apache .flink .runtime .executiongraph .ExecutionJobVertex ;
26
24
import org .apache .flink .runtime .executiongraph .IntermediateResult ;
27
25
import org .apache .flink .runtime .io .network .api .writer .SubtaskStateMapper ;
30
28
import org .apache .flink .runtime .jobgraph .OperatorInstanceID ;
31
29
import org .apache .flink .runtime .state .InputChannelStateHandle ;
32
30
import org .apache .flink .runtime .state .KeyedStateHandle ;
33
- import org .apache .flink .runtime .state .MergedInputChannelStateHandle ;
34
- import org .apache .flink .runtime .state .MergedResultSubpartitionStateHandle ;
35
31
import org .apache .flink .runtime .state .OperatorStateHandle ;
36
32
import org .apache .flink .runtime .state .ResultSubpartitionStateHandle ;
37
33
import org .apache .flink .runtime .state .StateObject ;
@@ -147,46 +143,15 @@ private static Set<Integer> extractInputStateGates(OperatorState operatorState)
147
143
return operatorState .getStates ().stream ()
148
144
.map (OperatorSubtaskState ::getInputChannelState )
149
145
.flatMap (Collection ::stream )
150
- .flatMapToInt (
151
- handle -> {
152
- if (handle instanceof InputChannelStateHandle ) {
153
- return IntStream .of (
154
- ((InputChannelStateHandle ) handle ).getInfo ().getGateIdx ());
155
- } else if (handle instanceof MergedInputChannelStateHandle ) {
156
- return ((MergedInputChannelStateHandle ) handle )
157
- .getInfos ().stream ().mapToInt (InputChannelInfo ::getGateIdx );
158
- } else {
159
- throw new IllegalStateException (
160
- "Invalid input channel state : " + handle .getClass ());
161
- }
162
- })
163
- .distinct ()
164
- .boxed ()
146
+ .map (handle -> handle .getInfo ().getGateIdx ())
165
147
.collect (Collectors .toSet ());
166
148
}
167
149
168
150
private static Set <Integer > extractOutputStatePartitions (OperatorState operatorState ) {
169
151
return operatorState .getStates ().stream ()
170
152
.map (OperatorSubtaskState ::getResultSubpartitionState )
171
153
.flatMap (Collection ::stream )
172
- .flatMapToInt (
173
- handle -> {
174
- if (handle instanceof ResultSubpartitionStateHandle ) {
175
- return IntStream .of (
176
- ((ResultSubpartitionStateHandle ) handle )
177
- .getInfo ()
178
- .getPartitionIdx ());
179
- } else if (handle instanceof MergedResultSubpartitionStateHandle ) {
180
- return ((MergedResultSubpartitionStateHandle ) handle )
181
- .getInfos ().stream ()
182
- .mapToInt (ResultSubpartitionInfo ::getPartitionIdx );
183
- } else {
184
- throw new IllegalStateException (
185
- "Invalid output channel state : " + handle .getClass ());
186
- }
187
- })
188
- .distinct ()
189
- .boxed ()
154
+ .map (handle -> handle .getInfo ().getPartitionIdx ())
190
155
.collect (Collectors .toSet ());
191
156
}
192
157
@@ -252,7 +217,8 @@ public OperatorSubtaskState getSubtaskState(OperatorInstanceID instanceID) {
252
217
return assignment .getOutputMapping (assignmentIndex , recompute );
253
218
},
254
219
inputSubtaskMappings ,
255
- this ::getInputMapping ))
220
+ this ::getInputMapping ,
221
+ true ))
256
222
.setOutputRescalingDescriptor (
257
223
createRescalingDescriptor (
258
224
instanceID ,
@@ -265,7 +231,8 @@ public OperatorSubtaskState getSubtaskState(OperatorInstanceID instanceID) {
265
231
return assignment .getInputMapping (assignmentIndex , recompute );
266
232
},
267
233
outputSubtaskMappings ,
268
- this ::getOutputMapping ))
234
+ this ::getOutputMapping ,
235
+ false ))
269
236
.build ();
270
237
}
271
238
@@ -314,7 +281,8 @@ private InflightDataRescalingDescriptor createRescalingDescriptor(
314
281
TaskStateAssignment [] connectedAssignments ,
315
282
BiFunction <TaskStateAssignment , Boolean , SubtasksRescaleMapping > mappingRetriever ,
316
283
Map <Integer , SubtasksRescaleMapping > subtaskGateOrPartitionMappings ,
317
- Function <Integer , SubtasksRescaleMapping > subtaskMappingCalculator ) {
284
+ Function <Integer , SubtasksRescaleMapping > subtaskMappingCalculator ,
285
+ boolean isInput ) {
318
286
if (!expectedOperatorID .equals (instanceID .getOperatorId ())) {
319
287
return InflightDataRescalingDescriptor .NO_RESCALE ;
320
288
}
@@ -337,7 +305,8 @@ private InflightDataRescalingDescriptor createRescalingDescriptor(
337
305
assignment -> mappingRetriever .apply (assignment , true ),
338
306
subtaskGateOrPartitionMappings ,
339
307
subtaskMappingCalculator ,
340
- rescaledChannelsMappings );
308
+ rescaledChannelsMappings ,
309
+ isInput );
341
310
342
311
if (Arrays .stream (gateOrPartitionDescriptors )
343
312
.allMatch (InflightDataGateOrPartitionRescalingDescriptor ::isIdentity )) {
@@ -356,10 +325,14 @@ private InflightDataRescalingDescriptor createRescalingDescriptor(
356
325
Function <TaskStateAssignment , SubtasksRescaleMapping > mappingCalculator ,
357
326
Map <Integer , SubtasksRescaleMapping > subtaskGateOrPartitionMappings ,
358
327
Function <Integer , SubtasksRescaleMapping > subtaskMappingCalculator ,
359
- SubtasksRescaleMapping [] rescaledChannelsMappings ) {
328
+ SubtasksRescaleMapping [] rescaledChannelsMappings ,
329
+ boolean isInput ) {
360
330
return IntStream .range (0 , rescaledChannelsMappings .length )
361
331
.mapToObj (
362
332
partition -> {
333
+ if (!hasInFlightData (isInput , partition )) {
334
+ return InflightDataGateOrPartitionRescalingDescriptor .NO_STATE ;
335
+ }
363
336
TaskStateAssignment connectedAssignment =
364
337
connectedAssignments [partition ];
365
338
SubtasksRescaleMapping rescaleMapping =
@@ -381,6 +354,14 @@ private InflightDataRescalingDescriptor createRescalingDescriptor(
381
354
.toArray (InflightDataGateOrPartitionRescalingDescriptor []::new );
382
355
}
383
356
357
+ private boolean hasInFlightData (boolean isInput , int gateOrPartitionIndex ) {
358
+ if (isInput ) {
359
+ return hasInFlightDataForInputGate (gateOrPartitionIndex );
360
+ } else {
361
+ return hasInFlightDataForResultPartition (gateOrPartitionIndex );
362
+ }
363
+ }
364
+
384
365
private InflightDataGateOrPartitionRescalingDescriptor
385
366
getInflightDataGateOrPartitionRescalingDescriptor (
386
367
OperatorInstanceID instanceID ,
@@ -479,6 +460,51 @@ public SubtasksRescaleMapping getInputMapping(int gateIndex) {
479
460
checkSubtaskMapping (oldMapping , mapping , mapper .isAmbiguous ()));
480
461
}
481
462
463
+ public boolean hasInFlightDataForInputGate (int gateIndex ) {
464
+ // Check own input state for this gate
465
+ if (inputStateGates .contains (gateIndex )) {
466
+ return true ;
467
+ }
468
+
469
+ // Check upstream output state for this gate
470
+ TaskStateAssignment upstreamAssignment = getUpstreamAssignments ()[gateIndex ];
471
+ if (upstreamAssignment != null && upstreamAssignment .hasOutputState ()) {
472
+ IntermediateResult inputResult = executionJobVertex .getInputs ().get (gateIndex );
473
+ IntermediateDataSetID resultId = inputResult .getId ();
474
+ IntermediateResult [] producedDataSets = inputResult .getProducer ().getProducedDataSets ();
475
+ for (int i = 0 ; i < producedDataSets .length ; i ++) {
476
+ if (producedDataSets [i ].getId ().equals (resultId )) {
477
+ return upstreamAssignment .outputStatePartitions .contains (i );
478
+ }
479
+ }
480
+ }
481
+
482
+ return false ;
483
+ }
484
+
485
+ public boolean hasInFlightDataForResultPartition (int partitionIndex ) {
486
+ // Check own output state for this partition
487
+ if (outputStatePartitions .contains (partitionIndex )) {
488
+ return true ;
489
+ }
490
+
491
+ // Check downstream input state for this partition
492
+ TaskStateAssignment downstreamAssignment = getDownstreamAssignments ()[partitionIndex ];
493
+
494
+ if (downstreamAssignment != null && downstreamAssignment .hasInputState ()) {
495
+ IntermediateResult producedResult =
496
+ executionJobVertex .getProducedDataSets ()[partitionIndex ];
497
+ IntermediateDataSetID resultId = producedResult .getId ();
498
+ List <IntermediateResult > inputs = downstreamAssignment .executionJobVertex .getInputs ();
499
+ for (int i = 0 ; i < inputs .size (); i ++) {
500
+ if (inputs .get (i ).getId ().equals (resultId )) {
501
+ return downstreamAssignment .inputStateGates .contains (i );
502
+ }
503
+ }
504
+ }
505
+ return false ;
506
+ }
507
+
482
508
@ Override
483
509
public String toString () {
484
510
return "TaskStateAssignment for " + executionJobVertex .getName ();
0 commit comments