@@ -23,6 +23,7 @@ import org.jetbrains.kotlin.fir.resolve.fullyExpandedType
23
23
import org.jetbrains.kotlin.fir.resolve.providers.symbolProvider
24
24
import org.jetbrains.kotlin.fir.resolve.toRegularClassSymbol
25
25
import org.jetbrains.kotlin.fir.resolve.toSymbol
26
+ import org.jetbrains.kotlin.fir.resolve.transformers.WhenOnSealedClassExhaustivenessChecker.ConditionChecker.processBranch
26
27
import org.jetbrains.kotlin.fir.symbols.FirBasedSymbol
27
28
import org.jetbrains.kotlin.fir.symbols.impl.FirClassSymbol
28
29
import org.jetbrains.kotlin.fir.symbols.impl.FirRegularClassSymbol
@@ -332,6 +333,12 @@ private object WhenOnBooleanExhaustivenessChecker : WhenExhaustivenessChecker()
332
333
var containsFalse = false
333
334
}
334
335
336
+ private fun recordValue (value : Any? , data : Flags ) = when (value) {
337
+ true -> data.containsTrue = true
338
+ false -> data.containsFalse = true
339
+ else -> {}
340
+ }
341
+
335
342
override fun computeMissingCases (
336
343
whenExpression : FirWhenExpression ,
337
344
subjectType : ConeKotlinType ,
@@ -345,6 +352,10 @@ private object WhenOnBooleanExhaustivenessChecker : WhenExhaustivenessChecker()
345
352
}
346
353
347
354
val flags = Flags ()
355
+ (whenExpression.subjectVariable?.initializer as ? FirSmartCastExpression )
356
+ ?.lowerTypesFromSmartCast
357
+ ?.mapNotNull { (it as ? DfaType .BooleanLiteral )?.value }
358
+ ?.forEach { recordValue(it, flags) }
348
359
whenExpression.accept(ConditionChecker , flags)
349
360
if (! flags.containsTrue) {
350
361
destination.add(WhenMissingCase .BooleanIsMissing .TrueIsMissing )
@@ -356,21 +367,10 @@ private object WhenOnBooleanExhaustivenessChecker : WhenExhaustivenessChecker()
356
367
357
368
private object ConditionChecker : AbstractConditionChecker<Flags>() {
358
369
override fun visitEqualityOperatorCall (equalityOperatorCall : FirEqualityOperatorCall , data : Flags ) {
359
- fun recordValue (value : Any? ) = when (value) {
360
- true -> data.containsTrue = true
361
- false -> data.containsFalse = true
362
- else -> {}
363
- }
364
-
365
370
if (equalityOperatorCall.operation.let { it == FirOperation .EQ || it == FirOperation .IDENTITY }) {
366
- (equalityOperatorCall.arguments.firstOrNull() as ? FirSmartCastExpression )
367
- ?.lowerTypesFromSmartCast
368
- ?.mapNotNull { (it as ? DfaType .BooleanLiteral )?.value }
369
- ?.forEach(::recordValue)
370
-
371
371
val argument = equalityOperatorCall.arguments[1 ]
372
372
if (argument is FirLiteralExpression ) {
373
- recordValue(argument.value)
373
+ recordValue(argument.value, data )
374
374
}
375
375
}
376
376
}
@@ -393,6 +393,15 @@ private object WhenOnEnumExhaustivenessChecker : WhenExhaustivenessChecker() {
393
393
394
394
val enumClass = (subjectType.toSymbol(session) as FirRegularClassSymbol ).fir
395
395
val notCheckedEntries = enumClass.declarations.mapNotNullTo(mutableSetOf ()) { it as ? FirEnumEntry }
396
+
397
+ whenExpression.subjectVariable?.initializer?.let { initializer ->
398
+ val knownNonValues = (initializer as ? FirSmartCastExpression )
399
+ ?.lowerTypesFromSmartCast
400
+ ?.mapNotNull { (it as ? DfaType .Symbol )?.symbol?.fir }
401
+ .orEmpty()
402
+ notCheckedEntries.removeAll(knownNonValues)
403
+ }
404
+
396
405
whenExpression.accept(ConditionChecker , notCheckedEntries)
397
406
notCheckedEntries.mapTo(destination) { WhenMissingCase .EnumCheckIsMissing (it.symbol.callableId) }
398
407
}
@@ -402,12 +411,6 @@ private object WhenOnEnumExhaustivenessChecker : WhenExhaustivenessChecker() {
402
411
if (! equalityOperatorCall.operation.let { it == FirOperation .EQ || it == FirOperation .IDENTITY }) return
403
412
val argument = equalityOperatorCall.arguments[1 ]
404
413
405
- val knownNonValues = (equalityOperatorCall.arguments.firstOrNull() as ? FirSmartCastExpression )
406
- ?.lowerTypesFromSmartCast
407
- ?.mapNotNull { (it as ? DfaType .Symbol )?.symbol?.fir }
408
- .orEmpty()
409
- data.removeAll(knownNonValues)
410
-
411
414
@OptIn(UnsafeExpressionUtility ::class )
412
415
val symbol = argument.toResolvedCallableReferenceUnsafe()?.resolvedSymbol as ? FirVariableSymbol <* > ? : return
413
416
val checkedEnumEntry = symbol.fir as ? FirEnumEntry ? : return
@@ -429,7 +432,12 @@ private object WhenOnSealedClassExhaustivenessChecker : WhenExhaustivenessChecke
429
432
) {
430
433
val allSubclasses = subjectType.toSymbol(session)?.collectAllSubclasses(session) ? : return
431
434
val checkedSubclasses = mutableSetOf<FirBasedSymbol <* >>()
432
- whenExpression.accept(ConditionChecker , Flags (allSubclasses, checkedSubclasses, session))
435
+ val flags = Flags (allSubclasses, checkedSubclasses, session)
436
+
437
+ whenExpression.subjectVariable?.initializer?.let { initializer ->
438
+ inferVariantsFromSubjectSmartCast(initializer, flags)
439
+ }
440
+ whenExpression.accept(ConditionChecker , flags)
433
441
(allSubclasses - checkedSubclasses).mapNotNullTo(destination) {
434
442
when (it) {
435
443
is FirClassSymbol <* > -> WhenMissingCase .IsTypeCheckIsMissing (
@@ -449,14 +457,26 @@ private object WhenOnSealedClassExhaustivenessChecker : WhenExhaustivenessChecke
449
457
val session : FirSession
450
458
)
451
459
460
+ private fun inferVariantsFromSubjectSmartCast (subject : FirExpression , data : Flags ) {
461
+ if (subject !is FirSmartCastExpression ) return
462
+
463
+ for (knownNonType in subject.lowerTypesFromSmartCast) {
464
+ val symbol = when (knownNonType) {
465
+ is DfaType .Cone -> knownNonType.type.toSymbol(data.session)
466
+ is DfaType .Symbol -> knownNonType.symbol
467
+ else -> null
468
+ }
469
+ processBranch(symbol ? : continue , isNegated = false , data)
470
+ }
471
+ }
472
+
452
473
private object ConditionChecker : AbstractConditionChecker<Flags>() {
453
474
override fun visitEqualityOperatorCall (equalityOperatorCall : FirEqualityOperatorCall , data : Flags ) {
454
475
val isNegated = when (equalityOperatorCall.operation) {
455
476
FirOperation .EQ , FirOperation .IDENTITY -> false
456
477
FirOperation .NOT_EQ , FirOperation .NOT_IDENTITY -> true
457
478
else -> return
458
479
}
459
- inferVariantsFromSubjectSmartCast(equalityOperatorCall, data)
460
480
val symbol = when (val argument = equalityOperatorCall.arguments[1 ].unwrapSmartcastExpression()) {
461
481
is FirResolvedQualifier -> {
462
482
val firClass = (argument.symbol as ? FirRegularClassSymbol )?.fir
@@ -480,25 +500,11 @@ private object WhenOnSealedClassExhaustivenessChecker : WhenExhaustivenessChecke
480
500
FirOperation .NOT_IS -> true
481
501
else -> return
482
502
}
483
- inferVariantsFromSubjectSmartCast(typeOperatorCall, data)
484
503
val symbol = typeOperatorCall.conversionTypeRef.coneType.fullyExpandedType(data.session).toSymbol(data.session) ? : return
485
504
processBranch(symbol, isNegated, data)
486
505
}
487
506
488
- private fun inferVariantsFromSubjectSmartCast (typeOperatorCall : FirCall , data : Flags ) {
489
- val subject = typeOperatorCall.arguments.firstOrNull() as ? FirSmartCastExpression ? : return
490
-
491
- for (knownNonType in subject.lowerTypesFromSmartCast) {
492
- val symbol = when (knownNonType) {
493
- is DfaType .Cone -> knownNonType.type.toSymbol(data.session)
494
- is DfaType .Symbol -> knownNonType.symbol
495
- else -> null
496
- }
497
- processBranch(symbol ? : continue , isNegated = false , data)
498
- }
499
- }
500
-
501
- private fun processBranch (symbolToCheck : FirBasedSymbol <* >, isNegated : Boolean , flags : Flags ) {
507
+ fun processBranch (symbolToCheck : FirBasedSymbol <* >, isNegated : Boolean , flags : Flags ) {
502
508
val subclassesOfType = symbolToCheck.collectAllSubclasses(flags.session)
503
509
if (subclassesOfType.none { it in flags.allSubclasses }) {
504
510
return
0 commit comments