Skip to content

Commit 1f504d9

Browse files
committed
Add hack for bool overload.
1 parent 67e2855 commit 1f504d9

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

gel/_internal/_codegen/_models/_pydantic.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3722,16 +3722,67 @@ def _write_potentially_overlapping_overloads(
37223722
# SEE ABOVE: This is what we actually want.
37233723
# key=lambda o: (generality_key(o), o.edgeql_signature), # noqa: ERA001, E501
37243724
)
3725+
base_generic_overload: dict[_Callable_T, _Callable_T] = {}
37253726

37263727
for overload in overloads:
37273728
overload_signatures[overload] = {}
3729+
3730+
if overload.schemapath == SchemaPath('std', 'IF'):
3731+
# HACK: Pretend the base overload of std::IF is generic on
3732+
# anyobject.
3733+
#
3734+
# The base overload of std::IF is
3735+
# (anytype, std::bool, anytype) -> anytype
3736+
#
3737+
# However, this causes an overlap with overloading for bool
3738+
# arguments since
3739+
# (anytype, builtin.bool, anytype) -> anytype
3740+
# overlaps with
3741+
# (std::bool, builtin.bool, std::bool) -> std::bool
3742+
#
3743+
# We resolve this by generating the specializations for anytype
3744+
# but using anyobject as the base generic type.
3745+
3746+
def anytype_to_anyobject(
3747+
refl_type: reflection.Type,
3748+
default: reflection.Type | reflection.TypeRef,
3749+
) -> reflection.Type | reflection.TypeRef:
3750+
if isinstance(refl_type, reflection.PseudoType):
3751+
return self._types_by_name["anyobject"]
3752+
return default
3753+
3754+
base_generic_overload[overload] = dataclasses.replace(
3755+
overload,
3756+
params=[
3757+
dataclasses.replace(
3758+
param,
3759+
type=anytype_to_anyobject(
3760+
param.get_type(self._types), param.type
3761+
),
3762+
)
3763+
for param in overload.params
3764+
],
3765+
return_type=anytype_to_anyobject(
3766+
overload.get_return_type(self._types),
3767+
overload.return_type,
3768+
),
3769+
)
3770+
37283771
for param in param_getter(overload):
37293772
param_overload_map[param.key].add(overload)
37303773
param_type = param.get_type(self._types)
37313774
# Unwrap the variadic type (it is reflected as an array of T)
37323775
if param.kind is reflection.CallableParamKind.Variadic:
37333776
if reflection.is_array_type(param_type):
37343777
param_type = param_type.get_element_type(self._types)
3778+
3779+
if (
3780+
overload.schemapath == SchemaPath('std', 'IF')
3781+
and param_type.is_pseudo
3782+
):
3783+
# Also generate the base signature using anyobject
3784+
param_type = self._types_by_name["anyobject"]
3785+
37353786
# Start with the base parameter type
37363787
overload_signatures[overload][param.key] = [param_type]
37373788

@@ -3843,7 +3894,10 @@ def specialization_sort_key(t: reflection.Type) -> int:
38433894
for overload in overloads:
38443895
if overload_specs := overloads_specializations.get(overload):
38453896
expanded_overloads.extend(overload_specs)
3846-
expanded_overloads.append(overload)
3897+
if overload in base_generic_overload:
3898+
expanded_overloads.append(base_generic_overload[overload])
3899+
else:
3900+
expanded_overloads.append(overload)
38473901
overloads = expanded_overloads
38483902

38493903
overload_order = {overload: i for i, overload in enumerate(overloads)}

gel/_internal/_reflection/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
ScalarType,
6969
TupleType,
7070
Type,
71+
TypeRef,
7172
compare_type_generality,
7273
fetch_types,
7374
is_abstract_type,
@@ -126,6 +127,7 @@
126127
"Type",
127128
"TypeKind",
128129
"TypeModifier",
130+
"TypeRef",
129131
"compare_callable_generality",
130132
"compare_type_generality",
131133
"fetch_branch_state",

0 commit comments

Comments
 (0)