diff --git a/middle_end/flambda/types/env/typing_env_extension.rec.ml b/middle_end/flambda/types/env/typing_env_extension.rec.ml index 0ba572334b67..e9c6c0d42483 100644 --- a/middle_end/flambda/types/env/typing_env_extension.rec.ml +++ b/middle_end/flambda/types/env/typing_env_extension.rec.ml @@ -51,6 +51,9 @@ let empty () = { equations = Name.Map.empty; } let is_empty { equations } = Name.Map.is_empty equations +let from_map equations = + { equations; } + let one_equation name ty = Type_grammar.check_equation name ty; { equations = Name.Map.singleton name ty; } diff --git a/middle_end/flambda/types/env/typing_env_extension.rec.mli b/middle_end/flambda/types/env/typing_env_extension.rec.mli index 1f25ed83c3a9..ff684d3d0440 100644 --- a/middle_end/flambda/types/env/typing_env_extension.rec.mli +++ b/middle_end/flambda/types/env/typing_env_extension.rec.mli @@ -41,6 +41,8 @@ val is_empty : t -> bool val one_equation : Name.t -> Type_grammar.t -> t +val from_map : Type_grammar.t Name.Map.t -> t + val add_or_replace_equation : t -> Name.t -> Type_grammar.t -> t val meet : Meet_env.t -> t -> t -> t Or_bottom.t diff --git a/middle_end/flambda/types/env/typing_env_level.rec.ml b/middle_end/flambda/types/env/typing_env_level.rec.ml index c3cd9268f66c..1b5e84e06a05 100644 --- a/middle_end/flambda/types/env/typing_env_level.rec.ml +++ b/middle_end/flambda/types/env/typing_env_level.rec.ml @@ -186,41 +186,51 @@ let concat (t1 : t) (t2 : t) = symbol_projections; } -let join_types ~env_at_fork envs_with_levels ~extra_lifted_consts_in_use_envs = +let join_types ~env_at_fork envs_with_levels = (* Add all the variables defined by the branches as existentials to the [env_at_fork]. Any such variable will be given type [Unknown] on a branch where it was not originally present. Iterating on [level.binding_times] instead of [level.defined_vars] ensures - consistency of binding time order in the branches and the result. *) - let env_at_fork = - List.fold_left (fun env_at_fork (_, _, _, level) -> - Binding_time.Map.fold (fun _ vars env -> - Variable.Set.fold (fun var env -> - if Typing_env.mem env (Name.var var) then env - else - let kind = Variable.Map.find var level.defined_vars in - Typing_env.add_definition env - (Name_in_binding_pos.var - (Var_in_binding_pos.create var Name_mode.in_types)) - kind) - vars - env) - level.binding_times - env_at_fork) + consistency of binding time order in the branches and the result. + In addition, this also aggregates the code age relations of the branches. + *) + let base_env = + List.fold_left (fun base_env (env_at_use, _, _, level) -> + let base_env = + Binding_time.Map.fold (fun _ vars base_env -> + Variable.Set.fold (fun var base_env -> + if Typing_env.mem base_env (Name.var var) then base_env + else + let kind = Variable.Map.find var level.defined_vars in + Typing_env.add_definition base_env + (Name_in_binding_pos.var + (Var_in_binding_pos.create var Name_mode.in_types)) + kind) + vars + base_env) + level.binding_times + base_env + in + let code_age_relation = + Code_age_relation.union (Typing_env.code_age_relation base_env) + (Typing_env.code_age_relation env_at_use) + in + Typing_env.with_code_age_relation base_env code_age_relation) env_at_fork envs_with_levels in (* Now fold over the levels doing the actual join operation on equations. *) ListLabels.fold_left envs_with_levels - ~init:(env_at_fork, Name.Map.empty, true) - ~f:(fun (join_env, joined_types, is_first_join) (env_at_use, _, _, t) -> - let join_env = - Code_age_relation.union (Typing_env.code_age_relation join_env) - (Typing_env.code_age_relation env_at_use) - |> Typing_env.with_code_age_relation join_env + ~init:(Name.Map.empty, true) + ~f:(fun (joined_types, is_first_join) (env_at_use, _, _, t) -> + let left_env = + (* CR vlaviron: This is very likely quadratic (number of uses times + number of variables in all uses). + However it's hard to know how we could do better. *) + Typing_env.add_env_extension base_env + (Typing_env_extension.from_map joined_types) in - let next_join_env = ref join_env in let join_types name joined_ty use_ty = (* CR mshinwell for vlaviron: Looks like [Typing_env.mem] needs fixing with respect to names from other units with their @@ -229,22 +239,20 @@ let join_types ~env_at_fork envs_with_levels ~extra_lifted_consts_in_use_envs = Compilation_unit.equal (Name.compilation_unit name) (Compilation_unit.get_current_exn ()) in - if same_unit && not (Typing_env.mem env_at_fork name) then begin - Misc.fatal_errorf "Name %a not defined in [env_at_fork]:@ %a" + if same_unit && not (Typing_env.mem base_env name) then begin + Misc.fatal_errorf "Name %a not defined in [base_env]:@ %a" Name.print name - Typing_env.print env_at_fork + Typing_env.print base_env end; - let is_lifted_const_symbol = - match Name.must_be_symbol_opt name with - | None -> false - | Some symbol -> - Symbol.Set.mem symbol extra_lifted_consts_in_use_envs - in (* If [name] is that of a lifted constant symbol generated during one of the levels, then ignore it. [Simplify_expr] will already have - made its type suitable for [env_at_fork] and inserted it into that - environment. *) - if is_lifted_const_symbol then None + made its type suitable for [base_env] and inserted it into that + environment. + If [name] is a symbol that is not a lifted constant, then it was + defined before the fork and already has an equation in base_env. + While it is possible that its type could be refined by all of the + branches, it is unlikely. *) + if Name.is_symbol name then None else let joined_ty = match joined_ty, use_ty with @@ -267,18 +275,13 @@ let join_types ~env_at_fork envs_with_levels ~extra_lifted_consts_in_use_envs = to case split. *) else let expected_kind = Some (Type_grammar.kind use_ty) in - Typing_env.find env_at_fork name expected_kind + Typing_env.find base_env name expected_kind in - (* Recall: the order of environments matters for [join]. - Also note that we use [env_at_fork] not [env_at_use] for - the right-hand environment. This is done because there may - be names in types in [env_at_fork] that are not defined in - [env_at_use] -- see the comment in [check_join_inputs] - below. *) + (* Recall: the order of environments matters for [join]. *) let join_env = - Join_env.create env_at_fork - ~left_env:join_env - ~right_env:env_at_fork + Join_env.create base_env + ~left_env + ~right_env:env_at_use in Type_grammar.join ~bound_name:name join_env left_ty use_ty @@ -287,14 +290,14 @@ let join_types ~env_at_fork envs_with_levels ~extra_lifted_consts_in_use_envs = the current level for [name]. However we have seen an equation for [name] on a previous level. We need to get the best type we can for [name] on the current level, from - [env_at_fork], similarly to the previous case. *) + [base_env], similarly to the previous case. *) assert (not is_first_join); let expected_kind = Some (Type_grammar.kind joined_ty) in - let right_ty = Typing_env.find env_at_fork name expected_kind in + let right_ty = Typing_env.find base_env name expected_kind in let join_env = - Join_env.create env_at_fork - ~left_env:join_env - ~right_env:env_at_fork + Join_env.create base_env + ~left_env + ~right_env:base_env in Type_grammar.join ~bound_name:name join_env joined_ty right_ty @@ -304,8 +307,8 @@ let join_types ~env_at_fork envs_with_levels ~extra_lifted_consts_in_use_envs = equation for [name] on the current level. *) assert (not is_first_join); let join_env = - Join_env.create env_at_fork - ~left_env:join_env + Join_env.create base_env + ~left_env ~right_env:env_at_use in Type_grammar.join ~bound_name:name @@ -314,15 +317,13 @@ let join_types ~env_at_fork envs_with_levels ~extra_lifted_consts_in_use_envs = in begin match joined_ty with | Known joined_ty -> - next_join_env := - Typing_env.add_equation !next_join_env name joined_ty; Some joined_ty | Unknown -> None end in let joined_types = Name.Map.merge join_types joined_types t.equations in - !next_join_env, joined_types, false) - |> fun (_, joined_types, _) -> + joined_types, false) + |> fun (joined_types, _) -> joined_types let construct_joined_level envs_with_levels ~env_at_fork ~allowed @@ -426,7 +427,7 @@ let join ~env_at_fork envs_with_levels ~params ~extra_lifted_consts_in_use_envs; (* Calculate the joined types of all the names involved. *) let joined_types = - join_types ~env_at_fork envs_with_levels ~extra_lifted_consts_in_use_envs + join_types ~env_at_fork envs_with_levels in (* Next calculate which equations (describing joined types) to propagate to the join point. (Recall that the environment at the fork point includes