2020import org .apache .flink .runtime .OperatorIDPair ;
2121import org .apache .flink .runtime .checkpoint .InflightDataRescalingDescriptor .InflightDataGateOrPartitionRescalingDescriptor ;
2222import org .apache .flink .runtime .checkpoint .InflightDataRescalingDescriptor .InflightDataGateOrPartitionRescalingDescriptor .MappingType ;
23- import org .apache .flink .runtime .checkpoint .channel .InputChannelInfo ;
24- import org .apache .flink .runtime .checkpoint .channel .ResultSubpartitionInfo ;
2523import org .apache .flink .runtime .executiongraph .ExecutionJobVertex ;
2624import org .apache .flink .runtime .executiongraph .IntermediateResult ;
2725import org .apache .flink .runtime .io .network .api .writer .SubtaskStateMapper ;
3028import org .apache .flink .runtime .jobgraph .OperatorInstanceID ;
3129import org .apache .flink .runtime .state .InputChannelStateHandle ;
3230import org .apache .flink .runtime .state .KeyedStateHandle ;
33- import org .apache .flink .runtime .state .MergedInputChannelStateHandle ;
34- import org .apache .flink .runtime .state .MergedResultSubpartitionStateHandle ;
3531import org .apache .flink .runtime .state .OperatorStateHandle ;
3632import org .apache .flink .runtime .state .ResultSubpartitionStateHandle ;
3733import org .apache .flink .runtime .state .StateObject ;
@@ -145,46 +141,15 @@ private static Set<Integer> extractInputStateGates(OperatorState operatorState)
145141 return operatorState .getStates ().stream ()
146142 .map (OperatorSubtaskState ::getInputChannelState )
147143 .flatMap (Collection ::stream )
148- .flatMapToInt (
149- handle -> {
150- if (handle instanceof InputChannelStateHandle ) {
151- return IntStream .of (
152- ((InputChannelStateHandle ) handle ).getInfo ().getGateIdx ());
153- } else if (handle instanceof MergedInputChannelStateHandle ) {
154- return ((MergedInputChannelStateHandle ) handle )
155- .getInfos ().stream ().mapToInt (InputChannelInfo ::getGateIdx );
156- } else {
157- throw new IllegalStateException (
158- "Invalid input channel state : " + handle .getClass ());
159- }
160- })
161- .distinct ()
162- .boxed ()
144+ .map (handle -> handle .getInfo ().getGateIdx ())
163145 .collect (Collectors .toSet ());
164146 }
165147
166148 private static Set <Integer > extractOutputStatePartitions (OperatorState operatorState ) {
167149 return operatorState .getStates ().stream ()
168150 .map (OperatorSubtaskState ::getResultSubpartitionState )
169151 .flatMap (Collection ::stream )
170- .flatMapToInt (
171- handle -> {
172- if (handle instanceof ResultSubpartitionStateHandle ) {
173- return IntStream .of (
174- ((ResultSubpartitionStateHandle ) handle )
175- .getInfo ()
176- .getPartitionIdx ());
177- } else if (handle instanceof MergedResultSubpartitionStateHandle ) {
178- return ((MergedResultSubpartitionStateHandle ) handle )
179- .getInfos ().stream ()
180- .mapToInt (ResultSubpartitionInfo ::getPartitionIdx );
181- } else {
182- throw new IllegalStateException (
183- "Invalid output channel state : " + handle .getClass ());
184- }
185- })
186- .distinct ()
187- .boxed ()
152+ .map (handle -> handle .getInfo ().getPartitionIdx ())
188153 .collect (Collectors .toSet ());
189154 }
190155
@@ -250,7 +215,8 @@ public OperatorSubtaskState getSubtaskState(OperatorInstanceID instanceID) {
250215 return assignment .getOutputMapping (assignmentIndex , recompute );
251216 },
252217 inputSubtaskMappings ,
253- this ::getInputMapping ))
218+ this ::getInputMapping ,
219+ true ))
254220 .setOutputRescalingDescriptor (
255221 createRescalingDescriptor (
256222 instanceID ,
@@ -263,7 +229,8 @@ public OperatorSubtaskState getSubtaskState(OperatorInstanceID instanceID) {
263229 return assignment .getInputMapping (assignmentIndex , recompute );
264230 },
265231 outputSubtaskMappings ,
266- this ::getOutputMapping ))
232+ this ::getOutputMapping ,
233+ false ))
267234 .build ();
268235 }
269236
@@ -312,7 +279,8 @@ private InflightDataRescalingDescriptor createRescalingDescriptor(
312279 TaskStateAssignment [] connectedAssignments ,
313280 BiFunction <TaskStateAssignment , Boolean , SubtasksRescaleMapping > mappingRetriever ,
314281 Map <Integer , SubtasksRescaleMapping > subtaskGateOrPartitionMappings ,
315- Function <Integer , SubtasksRescaleMapping > subtaskMappingCalculator ) {
282+ Function <Integer , SubtasksRescaleMapping > subtaskMappingCalculator ,
283+ boolean isInput ) {
316284 if (!expectedOperatorID .equals (instanceID .getOperatorId ())) {
317285 return InflightDataRescalingDescriptor .NO_RESCALE ;
318286 }
@@ -335,7 +303,8 @@ private InflightDataRescalingDescriptor createRescalingDescriptor(
335303 assignment -> mappingRetriever .apply (assignment , true ),
336304 subtaskGateOrPartitionMappings ,
337305 subtaskMappingCalculator ,
338- rescaledChannelsMappings );
306+ rescaledChannelsMappings ,
307+ isInput );
339308
340309 if (Arrays .stream (gateOrPartitionDescriptors )
341310 .allMatch (InflightDataGateOrPartitionRescalingDescriptor ::isIdentity )) {
@@ -354,10 +323,14 @@ private InflightDataRescalingDescriptor createRescalingDescriptor(
354323 Function <TaskStateAssignment , SubtasksRescaleMapping > mappingCalculator ,
355324 Map <Integer , SubtasksRescaleMapping > subtaskGateOrPartitionMappings ,
356325 Function <Integer , SubtasksRescaleMapping > subtaskMappingCalculator ,
357- SubtasksRescaleMapping [] rescaledChannelsMappings ) {
326+ SubtasksRescaleMapping [] rescaledChannelsMappings ,
327+ boolean isInput ) {
358328 return IntStream .range (0 , rescaledChannelsMappings .length )
359329 .mapToObj (
360330 partition -> {
331+ if (!hasInFlightData (isInput , partition )) {
332+ return InflightDataGateOrPartitionRescalingDescriptor .NO_STATE ;
333+ }
361334 TaskStateAssignment connectedAssignment =
362335 connectedAssignments [partition ];
363336 SubtasksRescaleMapping rescaleMapping =
@@ -379,6 +352,14 @@ private InflightDataRescalingDescriptor createRescalingDescriptor(
379352 .toArray (InflightDataGateOrPartitionRescalingDescriptor []::new );
380353 }
381354
355+ private boolean hasInFlightData (boolean isInput , int gateOrPartitionIndex ) {
356+ if (isInput ) {
357+ return hasInFlightDataForInputGate (gateOrPartitionIndex );
358+ } else {
359+ return hasInFlightDataForResultPartition (gateOrPartitionIndex );
360+ }
361+ }
362+
382363 private InflightDataGateOrPartitionRescalingDescriptor
383364 getInflightDataGateOrPartitionRescalingDescriptor (
384365 OperatorInstanceID instanceID ,
@@ -477,6 +458,51 @@ public SubtasksRescaleMapping getInputMapping(int gateIndex) {
477458 checkSubtaskMapping (oldMapping , mapping , mapper .isAmbiguous ()));
478459 }
479460
461+ public boolean hasInFlightDataForInputGate (int gateIndex ) {
462+ // Check own input state for this gate
463+ if (inputStateGates .contains (gateIndex )) {
464+ return true ;
465+ }
466+
467+ // Check upstream output state for this gate
468+ TaskStateAssignment upstreamAssignment = getUpstreamAssignments ()[gateIndex ];
469+ if (upstreamAssignment != null && upstreamAssignment .hasOutputState ()) {
470+ IntermediateResult inputResult = executionJobVertex .getInputs ().get (gateIndex );
471+ IntermediateDataSetID resultId = inputResult .getId ();
472+ IntermediateResult [] producedDataSets = inputResult .getProducer ().getProducedDataSets ();
473+ for (int i = 0 ; i < producedDataSets .length ; i ++) {
474+ if (producedDataSets [i ].getId ().equals (resultId )) {
475+ return upstreamAssignment .outputStatePartitions .contains (i );
476+ }
477+ }
478+ }
479+
480+ return false ;
481+ }
482+
483+ public boolean hasInFlightDataForResultPartition (int partitionIndex ) {
484+ // Check own output state for this partition
485+ if (outputStatePartitions .contains (partitionIndex )) {
486+ return true ;
487+ }
488+
489+ // Check downstream input state for this partition
490+ TaskStateAssignment downstreamAssignment = getDownstreamAssignments ()[partitionIndex ];
491+
492+ if (downstreamAssignment != null && downstreamAssignment .hasInputState ()) {
493+ IntermediateResult producedResult =
494+ executionJobVertex .getProducedDataSets ()[partitionIndex ];
495+ IntermediateDataSetID resultId = producedResult .getId ();
496+ List <IntermediateResult > inputs = downstreamAssignment .executionJobVertex .getInputs ();
497+ for (int i = 0 ; i < inputs .size (); i ++) {
498+ if (inputs .get (i ).getId ().equals (resultId )) {
499+ return downstreamAssignment .inputStateGates .contains (i );
500+ }
501+ }
502+ }
503+ return false ;
504+ }
505+
480506 @ Override
481507 public String toString () {
482508 return "TaskStateAssignment for " + executionJobVertex .getName ();
0 commit comments