Skip to content

Casting collections to array types on the front-end #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,30 @@ val ConeKotlinType.isArrayTypeOrNullableArrayType: Boolean get() = isArrayType(i
val ConeKotlinType.isNonPrimitiveArray: Boolean
get() = this is ConeClassLikeType && lookupTag.classId == StandardClassIds.Array

val ConeKotlinType.isIntArray: Boolean
get() = this is ConeClassLikeType && lookupTag.classId == StandardClassIds.primitiveArrayTypeByElementType[StandardClassIds.Int]

val ConeKotlinType.isLongArray: Boolean
get() = this is ConeClassLikeType && lookupTag.classId == StandardClassIds.primitiveArrayTypeByElementType[StandardClassIds.Long]

val ConeKotlinType.isFloatArray: Boolean
get() = this is ConeClassLikeType && lookupTag.classId == StandardClassIds.primitiveArrayTypeByElementType[StandardClassIds.Float]

val ConeKotlinType.isDoubleArray: Boolean
get() = this is ConeClassLikeType && lookupTag.classId == StandardClassIds.primitiveArrayTypeByElementType[StandardClassIds.Double]

val ConeKotlinType.isCharArray: Boolean
get() = this is ConeClassLikeType && lookupTag.classId == StandardClassIds.primitiveArrayTypeByElementType[StandardClassIds.Char]

val ConeKotlinType.isByteArray: Boolean
get() = this is ConeClassLikeType && lookupTag.classId == StandardClassIds.primitiveArrayTypeByElementType[StandardClassIds.Byte]

val ConeKotlinType.isBooleanArray: Boolean
get() = this is ConeClassLikeType && lookupTag.classId == StandardClassIds.primitiveArrayTypeByElementType[StandardClassIds.Boolean]

val ConeKotlinType.isShortArray: Boolean
get() = this is ConeClassLikeType && lookupTag.classId == StandardClassIds.primitiveArrayTypeByElementType[StandardClassIds.Short]

val ConeKotlinType.isPrimitiveArray: Boolean
get() = this is ConeClassLikeType && lookupTag.classId in StandardClassIds.primitiveArrayTypeByElementType.values

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.jetbrains.kotlin.fir.types

import org.jetbrains.kotlin.name.StandardClassIds
import org.jetbrains.kotlin.name.canBeSpreaded
import org.jetbrains.kotlin.utils.addToStdlib.runIf

val ConeKotlinType.isArrayOrPrimitiveArray: Boolean
Expand Down Expand Up @@ -42,6 +43,23 @@ fun ConeKotlinType.arrayElementType(checkUnsignedArrays: Boolean = true): ConeKo
}
}

fun ConeKotlinType.spreadableCollectionElementType(checkUnsignedArrays: Boolean = true): ConeKotlinType? {
return when (val argument = spreadableCollectionElementTypeArgument(checkUnsignedArrays)) {
is ConeKotlinTypeProjection -> argument.type
else -> null
}
}

private fun ConeKotlinType.spreadableCollectionElementTypeArgument(checkUnsignedArrays: Boolean = true): ConeTypeProjection? {
val type = this.lowerBoundIfFlexible()
if (type !is ConeClassLikeType) return null
val classId = type.lookupTag.classId
if (classId.canBeSpreaded()) {
return type.typeArguments.first()
}
return arrayElementType(checkUnsignedArrays)
}

private fun ConeKotlinType.arrayElementTypeArgument(checkUnsignedArrays: Boolean = true): ConeTypeProjection? {
val type = this.lowerBoundIfFlexible()
if (type !is ConeClassLikeType) return null
Expand All @@ -55,7 +73,6 @@ private fun ConeKotlinType.arrayElementTypeArgument(checkUnsignedArrays: Boolean
if (elementType != null) {
return elementType.constructClassLikeType(emptyArray(), isNullable = false)
}

return null
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ fun Candidate.resolveArgumentExpression(
sink: CheckerSink,
context: ResolutionContext,
isReceiver: Boolean,
isDispatch: Boolean
isDispatch: Boolean,
isSpread: Boolean = false,
) {
when (argument) {
is FirFunctionCall, is FirWhenExpression, is FirTryExpression, is FirCheckNotNullCall, is FirElvisExpression -> resolveSubCallArgument(
Expand All @@ -56,7 +57,8 @@ fun Candidate.resolveArgumentExpression(
sink,
context,
isReceiver,
isDispatch
isDispatch,
isSpread = isSpread
)
// x?.bar() is desugared to `x SAFE-CALL-OPERATOR { $not-null-receiver$.bar() }`
//
Expand All @@ -75,7 +77,8 @@ fun Candidate.resolveArgumentExpression(
context,
isReceiver,
isDispatch,
useNullableArgumentType = true
useNullableArgumentType = true,
isSpread = isSpread
)
} else {
// Assignment
Expand All @@ -88,7 +91,8 @@ fun Candidate.resolveArgumentExpression(
isReceiver = false,
isDispatch = false,
sink = sink,
context = context
context = context,
isSpread = isSpread
)
}
}
Expand All @@ -101,7 +105,8 @@ fun Candidate.resolveArgumentExpression(
sink,
context,
isReceiver,
isDispatch
isDispatch,
isSpread = isSpread
)
else
preprocessCallableReference(argument, expectedType, context)
Expand All @@ -113,7 +118,8 @@ fun Candidate.resolveArgumentExpression(
sink,
context,
isReceiver,
isDispatch
isDispatch,
isSpread = isSpread
)
is FirBlock -> resolveBlockArgument(
csBuilder,
Expand All @@ -122,9 +128,10 @@ fun Candidate.resolveArgumentExpression(
sink,
context,
isReceiver,
isDispatch
isDispatch,
isSpread =isSpread
)
else -> resolvePlainExpressionArgument(csBuilder, argument, expectedType, sink, context, isReceiver, isDispatch)
else -> resolvePlainExpressionArgument(csBuilder, argument, expectedType, sink, context, isReceiver, isDispatch, isSpread = isSpread)
}
}

Expand All @@ -135,7 +142,8 @@ private fun Candidate.resolveBlockArgument(
sink: CheckerSink,
context: ResolutionContext,
isReceiver: Boolean,
isDispatch: Boolean
isDispatch: Boolean,
isSpread: Boolean = false
) {
val returnArguments = block.returnExpressions()
if (returnArguments.isEmpty()) {
Expand All @@ -148,7 +156,8 @@ private fun Candidate.resolveBlockArgument(
isReceiver = false,
isDispatch = false,
sink = sink,
context = context
context = context,
isSpread = isSpread
)
return
}
Expand All @@ -160,7 +169,8 @@ private fun Candidate.resolveBlockArgument(
sink,
context,
isReceiver,
isDispatch
isDispatch,
isSpread = isSpread
)
}
}
Expand All @@ -173,7 +183,8 @@ fun Candidate.resolveSubCallArgument(
context: ResolutionContext,
isReceiver: Boolean,
isDispatch: Boolean,
useNullableArgumentType: Boolean = false
useNullableArgumentType: Boolean = false,
isSpread: Boolean = false
) {
require(argument is FirExpression)
val candidate = argument.candidate() ?: return resolvePlainExpressionArgument(
Expand All @@ -184,7 +195,8 @@ fun Candidate.resolveSubCallArgument(
context,
isReceiver,
isDispatch,
useNullableArgumentType
useNullableArgumentType,
isSpread = isSpread
)
/*
* It's important to extract type from argument neither from symbol, because of symbol contains
Expand All @@ -201,7 +213,8 @@ fun Candidate.resolveSubCallArgument(
context,
isReceiver,
isDispatch,
useNullableArgumentType
useNullableArgumentType,
isSpread = isSpread
)
}

Expand Down Expand Up @@ -233,7 +246,8 @@ fun Candidate.resolvePlainExpressionArgument(
context: ResolutionContext,
isReceiver: Boolean,
isDispatch: Boolean,
useNullableArgumentType: Boolean = false
useNullableArgumentType: Boolean = false,
isSpread: Boolean = false
) {

if (expectedType == null) return
Expand All @@ -251,7 +265,8 @@ fun Candidate.resolvePlainExpressionArgument(
context,
isReceiver,
isDispatch,
useNullableArgumentType
useNullableArgumentType,
isSpread = isSpread
)
}

Expand All @@ -264,7 +279,8 @@ fun Candidate.resolvePlainArgumentType(
context: ResolutionContext,
isReceiver: Boolean,
isDispatch: Boolean,
useNullableArgumentType: Boolean = false
useNullableArgumentType: Boolean = false,
isSpread: Boolean = false
) {
val position = if (isReceiver) ConeReceiverConstraintPosition(argument) else ConeArgumentConstraintPosition(argument)

Expand All @@ -289,7 +305,7 @@ fun Candidate.resolvePlainArgumentType(
}

checkApplicabilityForArgumentType(
csBuilder, argument, argumentTypeForApplicabilityCheck, expectedType, position, isReceiver, isDispatch, sink, context
csBuilder, argument, argumentTypeForApplicabilityCheck, expectedType, position, isReceiver, isDispatch, sink, context, isSpread = isSpread
)
}

Expand Down Expand Up @@ -332,16 +348,28 @@ private fun checkApplicabilityForArgumentType(
csBuilder: ConstraintSystemBuilder,
argument: FirExpression,
argumentTypeBeforeCapturing: ConeKotlinType,
expectedType: ConeKotlinType?,
initialExpectedType: ConeKotlinType?,
position: ConstraintPosition,
isReceiver: Boolean,
isDispatch: Boolean,
sink: CheckerSink,
context: ResolutionContext
context: ResolutionContext,
isSpread: Boolean = false
) {
if (expectedType == null) return
var expectedType = initialExpectedType ?: return
var argumentType = captureFromTypeParameterUpperBoundIfNeeded(argumentTypeBeforeCapturing, expectedType, context.session)

if (isSpread && !argumentType.isNullable) {
argumentType = argumentType.spreadableCollectionElementType()?.also {
expectedType =
expectedType.arrayElementType()
?: error(
"Could not retrieve expected element type for vararg parameter."
+ " Parameter type is ${expectedType.renderReadable()}"
)
} ?: argumentType
}

val argumentType = captureFromTypeParameterUpperBoundIfNeeded(argumentTypeBeforeCapturing, expectedType, context.session)

fun subtypeError(actualExpectedType: ConeKotlinType): ResolutionDiagnostic {
if (argument.isNullLiteral && actualExpectedType.nullability == ConeNullability.NOT_NULL) {
Expand Down Expand Up @@ -413,6 +441,7 @@ private fun checkApplicabilityForArgumentType(
}
}


if (!isReceiver) {
sink.reportDiagnostic(subtypeError(expectedType))
return
Expand All @@ -427,6 +456,7 @@ private fun checkApplicabilityForArgumentType(
sink.reportDiagnostic(InapplicableWrongReceiver(expectedType, argumentType))
}
}

}

internal fun Candidate.resolveArgument(
Expand All @@ -449,7 +479,8 @@ internal fun Candidate.resolveArgument(
sink,
context,
isReceiver,
false
false,
isSpread = argument is FirSpreadArgumentExpression && parameter?.isVararg == true
)
}

Expand Down
Loading