@@ -3722,16 +3722,67 @@ def _write_potentially_overlapping_overloads(
37223722 # SEE ABOVE: This is what we actually want.
37233723 # key=lambda o: (generality_key(o), o.edgeql_signature), # noqa: ERA001, E501
37243724 )
3725+ base_generic_overload : dict [_Callable_T , _Callable_T ] = {}
37253726
37263727 for overload in overloads :
37273728 overload_signatures [overload ] = {}
3729+
3730+ if overload .schemapath == SchemaPath ('std' , 'IF' ):
3731+ # HACK: Pretend the base overload of std::IF is generic on
3732+ # anyobject.
3733+ #
3734+ # The base overload of std::IF is
3735+ # (anytype, std::bool, anytype) -> anytype
3736+ #
3737+ # However, this causes an overlap with overloading for bool
3738+ # arguments since
3739+ # (anytype, builtin.bool, anytype) -> anytype
3740+ # overlaps with
3741+ # (std::bool, builtin.bool, std::bool) -> std::bool
3742+ #
3743+ # We resolve this by generating the specializations for anytype
3744+ # but using anyobject as the base generic type.
3745+
3746+ def anytype_to_anyobject (
3747+ refl_type : reflection .Type ,
3748+ default : reflection .Type | reflection .TypeRef ,
3749+ ) -> reflection .Type | reflection .TypeRef :
3750+ if isinstance (refl_type , reflection .PseudoType ):
3751+ return self ._types_by_name ["anyobject" ]
3752+ return default
3753+
3754+ base_generic_overload [overload ] = dataclasses .replace (
3755+ overload ,
3756+ params = [
3757+ dataclasses .replace (
3758+ param ,
3759+ type = anytype_to_anyobject (
3760+ param .get_type (self ._types ), param .type
3761+ ),
3762+ )
3763+ for param in overload .params
3764+ ],
3765+ return_type = anytype_to_anyobject (
3766+ overload .get_return_type (self ._types ),
3767+ overload .return_type ,
3768+ ),
3769+ )
3770+
37283771 for param in param_getter (overload ):
37293772 param_overload_map [param .key ].add (overload )
37303773 param_type = param .get_type (self ._types )
37313774 # Unwrap the variadic type (it is reflected as an array of T)
37323775 if param .kind is reflection .CallableParamKind .Variadic :
37333776 if reflection .is_array_type (param_type ):
37343777 param_type = param_type .get_element_type (self ._types )
3778+
3779+ if (
3780+ overload .schemapath == SchemaPath ('std' , 'IF' )
3781+ and param_type .is_pseudo
3782+ ):
3783+ # Also generate the base signature using anyobject
3784+ param_type = self ._types_by_name ["anyobject" ]
3785+
37353786 # Start with the base parameter type
37363787 overload_signatures [overload ][param .key ] = [param_type ]
37373788
@@ -3843,7 +3894,10 @@ def specialization_sort_key(t: reflection.Type) -> int:
38433894 for overload in overloads :
38443895 if overload_specs := overloads_specializations .get (overload ):
38453896 expanded_overloads .extend (overload_specs )
3846- expanded_overloads .append (overload )
3897+ if overload in base_generic_overload :
3898+ expanded_overloads .append (base_generic_overload [overload ])
3899+ else :
3900+ expanded_overloads .append (overload )
38473901 overloads = expanded_overloads
38483902
38493903 overload_order = {overload : i for i , overload in enumerate (overloads )}
0 commit comments