Skip to content

Commit c3e8cb9

Browse files
lunakolySpace Team
authored and
Space Team
committed
[FIR] Avoid reanalyzing smartcast types in every when branch
The only new information that can appear among the smartcast types when moving from one `when` branch to another is exactly what we already collect into `flags`.
1 parent f5596b2 commit c3e8cb9

File tree

8 files changed

+127
-35
lines changed

8 files changed

+127
-35
lines changed

analysis/low-level-api-fir/tests/org/jetbrains/kotlin/analysis/low/level/api/fir/diagnostic/compiler/based/LLDiagnosticsFe10TestGenerated.java

+6
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

analysis/low-level-api-fir/tests/org/jetbrains/kotlin/analysis/low/level/api/fir/diagnostic/compiler/based/LLReversedDiagnosticsFe10TestGenerated.java

+6
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

compiler/fir/analysis-tests/tests-gen/org/jetbrains/kotlin/test/runners/FirLightTreeOldFrontendDiagnosticsWithLatestLanguageVersionTestGenerated.java

+6
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

compiler/fir/analysis-tests/tests-gen/org/jetbrains/kotlin/test/runners/PhasedJvmDiagnosticLightTreeTestGenerated.java

+6
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

compiler/fir/analysis-tests/tests-gen/org/jetbrains/kotlin/test/runners/PhasedJvmDiagnosticPsiTestGenerated.java

+6
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/transformers/FirWhenExhaustivenessTransformer.kt

+41-35
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.jetbrains.kotlin.fir.resolve.fullyExpandedType
2323
import org.jetbrains.kotlin.fir.resolve.providers.symbolProvider
2424
import org.jetbrains.kotlin.fir.resolve.toRegularClassSymbol
2525
import org.jetbrains.kotlin.fir.resolve.toSymbol
26+
import org.jetbrains.kotlin.fir.resolve.transformers.WhenOnSealedClassExhaustivenessChecker.ConditionChecker.processBranch
2627
import org.jetbrains.kotlin.fir.symbols.FirBasedSymbol
2728
import org.jetbrains.kotlin.fir.symbols.impl.FirClassSymbol
2829
import org.jetbrains.kotlin.fir.symbols.impl.FirRegularClassSymbol
@@ -332,6 +333,12 @@ private object WhenOnBooleanExhaustivenessChecker : WhenExhaustivenessChecker()
332333
var containsFalse = false
333334
}
334335

336+
private fun recordValue(value: Any?, data: Flags) = when (value) {
337+
true -> data.containsTrue = true
338+
false -> data.containsFalse = true
339+
else -> {}
340+
}
341+
335342
override fun computeMissingCases(
336343
whenExpression: FirWhenExpression,
337344
subjectType: ConeKotlinType,
@@ -345,6 +352,10 @@ private object WhenOnBooleanExhaustivenessChecker : WhenExhaustivenessChecker()
345352
}
346353

347354
val flags = Flags()
355+
(whenExpression.subjectVariable?.initializer as? FirSmartCastExpression)
356+
?.lowerTypesFromSmartCast
357+
?.mapNotNull { (it as? DfaType.BooleanLiteral)?.value }
358+
?.forEach { recordValue(it, flags) }
348359
whenExpression.accept(ConditionChecker, flags)
349360
if (!flags.containsTrue) {
350361
destination.add(WhenMissingCase.BooleanIsMissing.TrueIsMissing)
@@ -356,21 +367,10 @@ private object WhenOnBooleanExhaustivenessChecker : WhenExhaustivenessChecker()
356367

357368
private object ConditionChecker : AbstractConditionChecker<Flags>() {
358369
override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: Flags) {
359-
fun recordValue(value: Any?) = when (value) {
360-
true -> data.containsTrue = true
361-
false -> data.containsFalse = true
362-
else -> {}
363-
}
364-
365370
if (equalityOperatorCall.operation.let { it == FirOperation.EQ || it == FirOperation.IDENTITY }) {
366-
(equalityOperatorCall.arguments.firstOrNull() as? FirSmartCastExpression)
367-
?.lowerTypesFromSmartCast
368-
?.mapNotNull { (it as? DfaType.BooleanLiteral)?.value }
369-
?.forEach(::recordValue)
370-
371371
val argument = equalityOperatorCall.arguments[1]
372372
if (argument is FirLiteralExpression) {
373-
recordValue(argument.value)
373+
recordValue(argument.value, data)
374374
}
375375
}
376376
}
@@ -393,6 +393,15 @@ private object WhenOnEnumExhaustivenessChecker : WhenExhaustivenessChecker() {
393393

394394
val enumClass = (subjectType.toSymbol(session) as FirRegularClassSymbol).fir
395395
val notCheckedEntries = enumClass.declarations.mapNotNullTo(mutableSetOf()) { it as? FirEnumEntry }
396+
397+
whenExpression.subjectVariable?.initializer?.let { initializer ->
398+
val knownNonValues = (initializer as? FirSmartCastExpression)
399+
?.lowerTypesFromSmartCast
400+
?.mapNotNull { (it as? DfaType.Symbol)?.symbol?.fir }
401+
.orEmpty()
402+
notCheckedEntries.removeAll(knownNonValues)
403+
}
404+
396405
whenExpression.accept(ConditionChecker, notCheckedEntries)
397406
notCheckedEntries.mapTo(destination) { WhenMissingCase.EnumCheckIsMissing(it.symbol.callableId) }
398407
}
@@ -402,12 +411,6 @@ private object WhenOnEnumExhaustivenessChecker : WhenExhaustivenessChecker() {
402411
if (!equalityOperatorCall.operation.let { it == FirOperation.EQ || it == FirOperation.IDENTITY }) return
403412
val argument = equalityOperatorCall.arguments[1]
404413

405-
val knownNonValues = (equalityOperatorCall.arguments.firstOrNull() as? FirSmartCastExpression)
406-
?.lowerTypesFromSmartCast
407-
?.mapNotNull { (it as? DfaType.Symbol)?.symbol?.fir }
408-
.orEmpty()
409-
data.removeAll(knownNonValues)
410-
411414
@OptIn(UnsafeExpressionUtility::class)
412415
val symbol = argument.toResolvedCallableReferenceUnsafe()?.resolvedSymbol as? FirVariableSymbol<*> ?: return
413416
val checkedEnumEntry = symbol.fir as? FirEnumEntry ?: return
@@ -429,7 +432,12 @@ private object WhenOnSealedClassExhaustivenessChecker : WhenExhaustivenessChecke
429432
) {
430433
val allSubclasses = subjectType.toSymbol(session)?.collectAllSubclasses(session) ?: return
431434
val checkedSubclasses = mutableSetOf<FirBasedSymbol<*>>()
432-
whenExpression.accept(ConditionChecker, Flags(allSubclasses, checkedSubclasses, session))
435+
val flags = Flags(allSubclasses, checkedSubclasses, session)
436+
437+
whenExpression.subjectVariable?.initializer?.let { initializer ->
438+
inferVariantsFromSubjectSmartCast(initializer, flags)
439+
}
440+
whenExpression.accept(ConditionChecker, flags)
433441
(allSubclasses - checkedSubclasses).mapNotNullTo(destination) {
434442
when (it) {
435443
is FirClassSymbol<*> -> WhenMissingCase.IsTypeCheckIsMissing(
@@ -449,14 +457,26 @@ private object WhenOnSealedClassExhaustivenessChecker : WhenExhaustivenessChecke
449457
val session: FirSession
450458
)
451459

460+
private fun inferVariantsFromSubjectSmartCast(subject: FirExpression, data: Flags) {
461+
if (subject !is FirSmartCastExpression) return
462+
463+
for (knownNonType in subject.lowerTypesFromSmartCast) {
464+
val symbol = when (knownNonType) {
465+
is DfaType.Cone -> knownNonType.type.toSymbol(data.session)
466+
is DfaType.Symbol -> knownNonType.symbol
467+
else -> null
468+
}
469+
processBranch(symbol ?: continue, isNegated = false, data)
470+
}
471+
}
472+
452473
private object ConditionChecker : AbstractConditionChecker<Flags>() {
453474
override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: Flags) {
454475
val isNegated = when (equalityOperatorCall.operation) {
455476
FirOperation.EQ, FirOperation.IDENTITY -> false
456477
FirOperation.NOT_EQ, FirOperation.NOT_IDENTITY -> true
457478
else -> return
458479
}
459-
inferVariantsFromSubjectSmartCast(equalityOperatorCall, data)
460480
val symbol = when (val argument = equalityOperatorCall.arguments[1].unwrapSmartcastExpression()) {
461481
is FirResolvedQualifier -> {
462482
val firClass = (argument.symbol as? FirRegularClassSymbol)?.fir
@@ -480,25 +500,11 @@ private object WhenOnSealedClassExhaustivenessChecker : WhenExhaustivenessChecke
480500
FirOperation.NOT_IS -> true
481501
else -> return
482502
}
483-
inferVariantsFromSubjectSmartCast(typeOperatorCall, data)
484503
val symbol = typeOperatorCall.conversionTypeRef.coneType.fullyExpandedType(data.session).toSymbol(data.session) ?: return
485504
processBranch(symbol, isNegated, data)
486505
}
487506

488-
private fun inferVariantsFromSubjectSmartCast(typeOperatorCall: FirCall, data: Flags) {
489-
val subject = typeOperatorCall.arguments.firstOrNull() as? FirSmartCastExpression ?: return
490-
491-
for (knownNonType in subject.lowerTypesFromSmartCast) {
492-
val symbol = when (knownNonType) {
493-
is DfaType.Cone -> knownNonType.type.toSymbol(data.session)
494-
is DfaType.Symbol -> knownNonType.symbol
495-
else -> null
496-
}
497-
processBranch(symbol ?: continue, isNegated = false, data)
498-
}
499-
}
500-
501-
private fun processBranch(symbolToCheck: FirBasedSymbol<*>, isNegated: Boolean, flags: Flags) {
507+
fun processBranch(symbolToCheck: FirBasedSymbol<*>, isNegated: Boolean, flags: Flags) {
502508
val subclassesOfType = symbolToCheck.collectAllSubclasses(flags.session)
503509
if (subclassesOfType.none { it in flags.allSubclasses }) {
504510
return
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// FIR_IDENTICAL
2+
// RUN_PIPELINE_TILL: FRONTEND
3+
4+
fun upperType(): Int {
5+
var a: Boolean? = false
6+
val block: () -> Unit = { a = null }
7+
8+
if (a == null) return 1
9+
return <!NO_ELSE_IN_WHEN!>when<!> (a) {
10+
true -> 2
11+
false -> 3
12+
}
13+
}
14+
15+
fun booleanLiteral(): Int {
16+
var a: Boolean = true
17+
val block: () -> Unit = { a = false }
18+
19+
if (a == false) return 1
20+
return <!NO_ELSE_IN_WHEN!>when<!> (a) {
21+
true -> 2
22+
}
23+
}
24+
25+
enum class EnumBoolean { False, True }
26+
27+
fun enumEntry(): Int {
28+
var a: EnumBoolean = EnumBoolean.True
29+
val block: () -> Unit = { a = EnumBoolean.False }
30+
31+
if (a == EnumBoolean.False) return 1
32+
return <!NO_ELSE_IN_WHEN!>when<!> (a) {
33+
EnumBoolean.True -> 2
34+
}
35+
}
36+
37+
sealed class SealedBoolean {
38+
data object True : SealedBoolean()
39+
data object False : SealedBoolean()
40+
}
41+
42+
fun sealedVariant(): Int {
43+
var a: SealedBoolean = SealedBoolean.True
44+
val block: () -> Unit = { a = SealedBoolean.False }
45+
46+
if (a == SealedBoolean.False) return 1
47+
return <!NO_ELSE_IN_WHEN!>when<!> (a) {
48+
is SealedBoolean.True -> 2
49+
}
50+
}

compiler/tests-common-new/tests-gen/org/jetbrains/kotlin/test/runners/DiagnosticTestGenerated.java

+6
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)