diff --git a/doc/missed_optimisations.md b/doc/missed_optimisations.md new file mode 100644 index 000000000000..3e8df3950873 --- /dev/null +++ b/doc/missed_optimisations.md @@ -0,0 +1,8 @@ +# List of missed optimisations to look at later + +This should be compiled to the identity. But the CSE environment seems to lose +the relationship in the join after the first switch. + + let f x = + let not_x = if x then false else true in + if not_x then false else true diff --git a/flambdatest/mlexamples/involutions.ml b/flambdatest/mlexamples/involutions.ml new file mode 100644 index 000000000000..3cac5398fc69 --- /dev/null +++ b/flambdatest/mlexamples/involutions.ml @@ -0,0 +1,18 @@ + +let f x = + not (not (not x)) + +let g x = + ~- (~- (~- x)) + +let h x = + Int64.(neg (neg (neg x))) + +let i x = + Int32.(neg (neg (neg x))) + +let j x = + Nativeint.(neg (neg (neg x))) + +let k x = + ~-. (~-. (~-. x)) diff --git a/middle_end/flambda/simplify/simplify_primitive.ml b/middle_end/flambda/simplify/simplify_primitive.ml index 5ecb6ef767a3..d3d898289427 100644 --- a/middle_end/flambda/simplify/simplify_primitive.ml +++ b/middle_end/flambda/simplify/simplify_primitive.ml @@ -29,13 +29,16 @@ let apply_cse dacc ~original_prim = match DE.find_cse (DA.denv dacc) with_fixed_value with | None -> None | Some simple -> - let canonical = + match TE.get_canonical_simple_exn (DA.typing_env dacc) simple ~min_name_mode:NM.normal ~name_mode_of_existing_simple:NM.normal - in - match canonical with - | exception Not_found -> None + with + | exception Not_found -> + (* CR pchambart: this exception was not caught for some time, it is + expected never to happen, hence the fatal_error *) + Misc.fatal_errorf "No canonical simple for the CSE candidate: %a" + Simple.print simple | simple -> Some simple let try_cse dacc ~original_prim ~simplified_args_with_tys ~min_name_mode @@ -47,12 +50,13 @@ let try_cse dacc ~original_prim ~simplified_args_with_tys ~min_name_mode if not (Name_mode.equal min_name_mode Name_mode.normal) then Not_applied dacc else let result_var = VB.var result_var in + let args = List.map fst simplified_args_with_tys in + let original_prim = P.update_args original_prim args in match apply_cse dacc ~original_prim with | Some replace_with -> let named = Named.create_simple replace_with in let ty = T.alias_type_of (P.result_kind' original_prim) replace_with in let env_extension = TEE.one_equation (Name.var result_var) ty in - let args = List.map fst simplified_args_with_tys in let simplified_named = let cost_metrics = Cost_metrics.notify_removed diff --git a/middle_end/flambda/simplify/simplify_unary_primitive.ml b/middle_end/flambda/simplify/simplify_unary_primitive.ml index c730f6292709..6069c3620e9f 100644 --- a/middle_end/flambda/simplify/simplify_unary_primitive.ml +++ b/middle_end/flambda/simplify/simplify_unary_primitive.ml @@ -507,4 +507,14 @@ let simplify_unary_primitive dacc (prim : P.unary_primitive) let reachable, env_extension, dacc = simplifier dacc ~original_term ~arg ~arg_ty ~result_var in + let dacc = + if P.is_an_involution prim then + DA.map_denv dacc ~f:(fun denv -> + let prim : P.t = + Unary (prim, Simple.var (Var_in_binding_pos.var result_var)) + in + DE.add_cse denv (P.Eligible_for_cse.create_exn prim) ~bound_to:arg) + else + dacc + in reachable, env_extension, [arg], dacc diff --git a/middle_end/flambda/terms/flambda_primitive.ml b/middle_end/flambda/terms/flambda_primitive.ml index a133f2557234..c95f76f51119 100644 --- a/middle_end/flambda/terms/flambda_primitive.ml +++ b/middle_end/flambda/terms/flambda_primitive.ml @@ -944,6 +944,18 @@ let unary_classify_for_printing p = | Select_closure _ | Project_var _ -> Destructive +let is_an_involution (p : unary_primitive) = + match p with + | Boolean_not -> true + | Int_arith (_, Neg) -> true + | Float_arith Neg -> true + (* Checked through using the following SMT formula with Z3 + {[ (set-logic QF_FP) + (declare-fun x () (_ FloatingPoint 11 53)) + (assert (not (= x (fp.neg (fp.neg x))))) + (check-sat) ]} *) + | _ -> false + type binary_int_arith_op = | Add | Sub | Mul | Div | Mod | And | Or | Xor @@ -1626,6 +1638,19 @@ let args t = | Ternary (_, x0, x1, x2) -> [x0; x1; x2] | Variadic (_, xs) -> xs +let update_args t args = + match t, args with + | Nullary _, [] -> t + | Unary (p, _), [x0] -> Unary (p, x0) + | Binary (p, _, _), [x0; x1] -> Binary (p, x0, x1) + | Ternary (p, _, _, _), [x0; x1; x2] -> Ternary (p, x0, x1, x2) + | Variadic (p, l), xs -> + assert(List.length l = List.length xs); + Variadic (p, xs) + | _, _ -> + Misc.fatal_errorf "Wrong arity for updating primitive %a with arguments %a" + print t Simple.List.print args + let result_kind (t : t) = match t with | Nullary prim -> result_kind_of_nullary_primitive prim diff --git a/middle_end/flambda/terms/flambda_primitive.mli b/middle_end/flambda/terms/flambda_primitive.mli index 240065f22efc..d84468709851 100644 --- a/middle_end/flambda/terms/flambda_primitive.mli +++ b/middle_end/flambda/terms/flambda_primitive.mli @@ -334,6 +334,8 @@ include Contains_ids.S with type t := t val args : t -> Simple.t list +val update_args : t -> Simple.t list -> t + (** Simpler version (e.g. for [Inlining_cost]), where only the actual primitive matters, not the arguments. *) module Without_args : sig @@ -404,6 +406,10 @@ val at_most_generative_effects : t -> bool and no other effects. *) val only_generative_effects : t -> bool +(** Returns [true] iff the primitive is an involution. i.e. + x = Prim(y) iff y = Prim(x) *) +val is_an_involution : unary_primitive -> bool + module Eligible_for_cse : sig (** Primitive applications that may be replaced by a variable which is let bound to a single instance of such application. Primitives that are diff --git a/testsuite/tests/flambda2/involution.ml b/testsuite/tests/flambda2/involution.ml new file mode 100644 index 000000000000..6138f216bfa8 --- /dev/null +++ b/testsuite/tests/flambda2/involution.ml @@ -0,0 +1,45 @@ +(* TEST +* flambda +** native +*) + +(* This tests that involutions are effectively detected. + It checks that for an involution f, f (f x) == x + This is verified by testing that the result of the physical equality + is a known constant, through an allocation test. *) + +external minor_words : unit -> (float [@unboxed]) + = "caml_gc_minor_words" "caml_gc_minor_words_unboxed" + +let[@inline never] check_no_alloc test_name f x = + let before = minor_words () in + let _ = Sys.opaque_identity ((f[@inlined never]) x) in + let after = minor_words () in + let diff = after -. before in + if diff = 0. then + Format.printf "No allocs for test '%s'@." test_name + else + Format.printf "Some allocs for test '%s'@." test_name + +let () = + let[@inline] test_known f = + let tester x = + let x = Sys.opaque_identity x in + let n = if (f[@inlined hint]) ((f[@inlined hint]) x) == x then 1 else 0 in + Int64.of_int n + in + let () = Sys.opaque_identity () in + tester + in + + (* Test the test: this should allocate *) + check_no_alloc "not an involution" (test_known (succ)) 1; + (* Actual tests *) + check_no_alloc "not" (test_known (not)) true; + check_no_alloc "~-." (test_known (~-.)) 42.; + check_no_alloc "~-" (test_known (~-)) 42; + check_no_alloc "Int32.neg" (test_known Int32.neg) 42l; + check_no_alloc "Int64.neg" (test_known Int64.neg) 42L; + check_no_alloc "Nativeint.neg" (test_known Nativeint.neg) 42n; + () +