Skip to content

Commit 800424e

Browse files
committed
Properly infer specific enum values, fix test
1 parent c7e2c16 commit 800424e

File tree

2 files changed

+79
-19
lines changed

2 files changed

+79
-19
lines changed

src/Type/Doctrine/Query/QueryResultTypeWalker.php

Lines changed: 72 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
use PHPStan\Type\Constant\ConstantBooleanType;
2828
use PHPStan\Type\Constant\ConstantFloatType;
2929
use PHPStan\Type\Constant\ConstantIntegerType;
30+
use PHPStan\Type\Constant\ConstantStringType;
3031
use PHPStan\Type\ConstantTypeHelper;
3132
use PHPStan\Type\Doctrine\DescriptorNotRegisteredException;
3233
use PHPStan\Type\Doctrine\DescriptorRegistry;
@@ -54,6 +55,7 @@
5455
use function get_class;
5556
use function gettype;
5657
use function in_array;
58+
use function is_array;
5759
use function is_int;
5860
use function is_numeric;
5961
use function is_object;
@@ -287,13 +289,13 @@ public function walkPathExpression($pathExpr): string
287289

288290
switch ($pathExpr->type) {
289291
case AST\PathExpression::TYPE_STATE_FIELD:
290-
[$typeName, $enumType] = $this->getTypeOfField($class, $fieldName);
292+
[$typeName, $enumType, $enumValues] = $this->getTypeOfField($class, $fieldName);
291293

292294
$nullable = $this->isQueryComponentNullable($dqlAlias)
293295
|| $class->isNullable($fieldName)
294296
|| $this->hasAggregateWithoutGroupBy();
295297

296-
$fieldType = $this->resolveDatabaseInternalType($typeName, $enumType, $nullable);
298+
$fieldType = $this->resolveDatabaseInternalType($typeName, $enumType, $enumValues, $nullable);
297299

298300
return $this->marshalType($fieldType);
299301

@@ -327,12 +329,12 @@ public function walkPathExpression($pathExpr): string
327329
}
328330

329331
$targetFieldName = $identifierFieldNames[0];
330-
[$typeName, $enumType] = $this->getTypeOfField($targetClass, $targetFieldName);
332+
[$typeName, $enumType, $enumValues] = $this->getTypeOfField($targetClass, $targetFieldName);
331333

332334
$nullable = ($joinColumn['nullable'] ?? true)
333335
|| $this->hasAggregateWithoutGroupBy();
334336

335-
$fieldType = $this->resolveDatabaseInternalType($typeName, $enumType, $nullable);
337+
$fieldType = $this->resolveDatabaseInternalType($typeName, $enumType, $enumValues, $nullable);
336338

337339
return $this->marshalType($fieldType);
338340

@@ -686,7 +688,7 @@ public function walkFunction($function): string
686688
return $this->marshalType(new MixedType());
687689
}
688690

689-
[$typeName, $enumType] = $this->getTypeOfField($targetClass, $targetFieldName);
691+
[$typeName, $enumType, $enumValues] = $this->getTypeOfField($targetClass, $targetFieldName);
690692

691693
if (!isset($assoc['joinColumns'])) {
692694
return $this->marshalType(new MixedType());
@@ -709,7 +711,7 @@ public function walkFunction($function): string
709711
|| $this->isQueryComponentNullable($dqlAlias)
710712
|| $this->hasAggregateWithoutGroupBy();
711713

712-
$fieldType = $this->resolveDatabaseInternalType($typeName, $enumType, $nullable);
714+
$fieldType = $this->resolveDatabaseInternalType($typeName, $enumType, $enumValues, $nullable);
713715

714716
return $this->marshalType($fieldType);
715717

@@ -1207,13 +1209,13 @@ public function walkSelectExpression($selectExpression): string
12071209
assert(array_key_exists('metadata', $qComp));
12081210
$class = $qComp['metadata'];
12091211

1210-
[$typeName, $enumType] = $this->getTypeOfField($class, $fieldName);
1212+
[$typeName, $enumType, $enumValues] = $this->getTypeOfField($class, $fieldName);
12111213

12121214
$nullable = $this->isQueryComponentNullable($dqlAlias)
12131215
|| $class->isNullable($fieldName)
12141216
|| $this->hasAggregateWithoutGroupBy();
12151217

1216-
$type = $this->resolveDoctrineType($typeName, $enumType, $nullable);
1218+
$type = $this->resolveDoctrineType($typeName, $enumType, $enumValues, $nullable);
12171219

12181220
$this->typeBuilder->addScalar($resultAlias, $type);
12191221

@@ -1241,7 +1243,7 @@ public function walkSelectExpression($selectExpression): string
12411243
$dbalTypeName = DbalType::getTypeRegistry()->lookupName($expr->getReturnType());
12421244
$type = TypeCombinator::intersect( // e.g. count is typed as int, but we infer int<0, max>
12431245
$type,
1244-
$this->resolveDoctrineType($dbalTypeName, null, TypeCombinator::containsNull($type)),
1246+
$this->resolveDoctrineType($dbalTypeName, null, null, TypeCombinator::containsNull($type)),
12451247
);
12461248

12471249
if ($this->hasAggregateWithoutGroupBy() && !$expr instanceof AST\Functions\CountFunction) {
@@ -1999,7 +2001,7 @@ private function isQueryComponentNullable(string $dqlAlias): bool
19992001

20002002
/**
20012003
* @param ClassMetadata<object> $class
2002-
* @return array{string, ?class-string<BackedEnum>} Doctrine type name and enum type of field
2004+
* @return array{string, ?class-string<BackedEnum>, ?list<string>} Doctrine type name, enum type of field, enum values
20032005
*/
20042006
private function getTypeOfField(ClassMetadata $class, string $fieldName): array
20052007
{
@@ -2017,11 +2019,45 @@ private function getTypeOfField(ClassMetadata $class, string $fieldName): array
20172019
$enumType = null;
20182020
}
20192021

2020-
return [$type, $enumType];
2022+
return [$type, $enumType, $this->detectEnumValues($type, $metadata)];
20212023
}
20222024

2023-
/** @param ?class-string<BackedEnum> $enumType */
2024-
private function resolveDoctrineType(string $typeName, ?string $enumType = null, bool $nullable = false): Type
2025+
/**
2026+
* @param mixed $metadata
2027+
*
2028+
* @return list<string>|null
2029+
*/
2030+
private function detectEnumValues(string $typeName, $metadata): ?array
2031+
{
2032+
if ($typeName !== 'enum') {
2033+
return null;
2034+
}
2035+
2036+
$values = $metadata['options']['values'] ?? [];
2037+
2038+
if (!is_array($values) || count($values) === 0) {
2039+
return null;
2040+
}
2041+
2042+
foreach ($values as $value) {
2043+
if (!is_string($value)) {
2044+
return null;
2045+
}
2046+
}
2047+
2048+
return array_values($values);
2049+
}
2050+
2051+
/**
2052+
* @param ?class-string<BackedEnum> $enumType
2053+
* @param ?list<string> $enumValues
2054+
*/
2055+
private function resolveDoctrineType(
2056+
string $typeName,
2057+
?string $enumType = null,
2058+
?array $enumValues = null,
2059+
bool $nullable = false
2060+
): Type
20252061
{
20262062
try {
20272063
$type = $this->descriptorRegistry
@@ -2038,8 +2074,14 @@ private function resolveDoctrineType(string $typeName, ?string $enumType = null,
20382074
), ...TypeUtils::getAccessoryTypes($type));
20392075
}
20402076
}
2077+
2078+
if ($enumValues !== null) {
2079+
$enumValuesType = TypeCombinator::union(...array_map(static fn (string $value) => new ConstantStringType($value), $enumValues));
2080+
$type = TypeCombinator::intersect($enumValuesType, $type);
2081+
}
2082+
20412083
if ($type instanceof NeverType) {
2042-
$type = new MixedType();
2084+
$type = new MixedType();
20432085
}
20442086
} catch (DescriptorNotRegisteredException $e) {
20452087
if ($enumType !== null) {
@@ -2053,11 +2095,19 @@ private function resolveDoctrineType(string $typeName, ?string $enumType = null,
20532095
$type = TypeCombinator::addNull($type);
20542096
}
20552097

2056-
return $type;
2098+
return $type;
20572099
}
20582100

2059-
/** @param ?class-string<BackedEnum> $enumType */
2060-
private function resolveDatabaseInternalType(string $typeName, ?string $enumType = null, bool $nullable = false): Type
2101+
/**
2102+
* @param ?class-string<BackedEnum> $enumType
2103+
* @param ?list<string> $enumValues
2104+
*/
2105+
private function resolveDatabaseInternalType(
2106+
string $typeName,
2107+
?string $enumType = null,
2108+
?array $enumValues = null,
2109+
bool $nullable = false
2110+
): Type
20612111
{
20622112
try {
20632113
$descriptor = $this->descriptorRegistry->get($typeName);
@@ -2076,6 +2126,11 @@ private function resolveDatabaseInternalType(string $typeName, ?string $enumType
20762126
$type = TypeCombinator::intersect($enumType, $type);
20772127
}
20782128

2129+
if ($enumValues !== null) {
2130+
$enumValuesType = TypeCombinator::union(...array_map(static fn (string $value) => new ConstantStringType($value), $enumValues));
2131+
$type = TypeCombinator::intersect($enumValuesType, $type);
2132+
}
2133+
20792134
if ($nullable) {
20802135
$type = TypeCombinator::addNull($type);
20812136
}

tests/Type/Doctrine/Query/QueryResultTypeWalkerTest.php

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
use PHPStan\Type\StringType;
3636
use PHPStan\Type\Type;
3737
use PHPStan\Type\TypeCombinator;
38+
use PHPStan\Type\UnionType;
3839
use PHPStan\Type\VerbosityLevel;
3940
use QueryResult\Entities\Embedded;
4041
use QueryResult\Entities\JoinedChild;
@@ -1547,7 +1548,11 @@ private function yieldConditionalDataset(): iterable
15471548
$this->constantArray([
15481549
[
15491550
new ConstantStringType('enum'),
1550-
new StringType(),
1551+
new UnionType([
1552+
new ConstantStringType('a'),
1553+
new ConstantStringType('b'),
1554+
new ConstantStringType('c'),
1555+
]),
15511556
],
15521557
[
15531558
new ConstantStringType('smallfloat'),
@@ -1556,7 +1561,7 @@ private function yieldConditionalDataset(): iterable
15561561
]),
15571562
'
15581563
SELECT e.enum, e.smallfloat
1559-
FROM QueryResult\Entities\Dbal4Entity e
1564+
FROM QueryResult\EntitiesDbal42\Dbal4Entity e
15601565
',
15611566
];
15621567
}

0 commit comments

Comments
 (0)