diff --git a/arrayjit/lib/assignments.ml b/arrayjit/lib/assignments.ml index 16eb68c5..dc67d9c4 100644 --- a/arrayjit/lib/assignments.ml +++ b/arrayjit/lib/assignments.ml @@ -24,6 +24,16 @@ and t = | Noop | Seq of t * t | Block_comment of string * t (** Same as the given code, with a comment. *) + | Accum_ternop of { + initialize_neutral : bool; + accum : Ops.binop; + op : Ops.ternop; + lhs : Tn.t; + rhs1 : buffer; + rhs2 : buffer; + rhs3 : buffer; + projections : Indexing.projections Lazy.t; + } | Accum_binop of { initialize_neutral : bool; accum : Ops.binop; @@ -93,6 +103,8 @@ let%debug3_sexp context_nodes ~(use_host_memory : 'a option) (asgns : t) : Tn.t_ | Accum_unop { lhs; rhs; _ } -> Set.union (one lhs) (of_node rhs) | Accum_binop { lhs; rhs1; rhs2; _ } -> Set.union_list (module Tn) [ one lhs; of_node rhs1; of_node rhs2 ] + | Accum_ternop { lhs; rhs1; rhs2; rhs3; _ } -> + Set.union_list (module Tn) [ one lhs; of_node rhs1; of_node rhs2; of_node rhs3 ] | Fetch { array; _ } -> one array in loop asgns @@ -139,98 +151,60 @@ let%diagn2_sexp to_low_level code = assert (Array.length idcs = Array.length (Lazy.force tn.Tn.dims)); Low_level.Set { tn; idcs; llv; debug = "" } in - let rec loop code = + let rec loop_accum ~initialize_neutral ~accum ~op ~lhs ~rhses projections = + let projections = Lazy.force projections in + let lhs_idx = + derive_index ~product_syms:projections.product_iterators ~projection:projections.project_lhs + in + let rhs_idcs = + Array.map projections.project_rhs ~f:(fun projection -> + derive_index ~product_syms:projections.product_iterators ~projection) + in + let basecase rev_iters = + let product = Array.of_list_rev_map rev_iters ~f:(fun s -> Indexing.Iterator s) in + let rhses_idcs = Array.map rhs_idcs ~f:(fun rhs_idx -> rhs_idx ~product) in + let lhs_idcs = lhs_idx ~product in + let open Low_level in + let lhs_ll = get (Node lhs) lhs_idcs in + let rhses_ll = Array.mapi rhses_idcs ~f:(fun i rhs_idcs -> get rhses.(i) rhs_idcs) in + let rhs2 = 
apply_op op rhses_ll in + if is_total ~initialize_neutral ~projections then set lhs lhs_idcs rhs2 + else set lhs lhs_idcs @@ apply_op (Ops.Binop accum) [| lhs_ll; rhs2 |] + in + let rec for_loop rev_iters = function + | [] -> basecase rev_iters + | d :: product -> + let index = Indexing.get_symbol () in + For_loop + { + index; + from_ = 0; + to_ = d - 1; + body = for_loop (index :: rev_iters) product; + trace_it = true; + } + in + let for_loops = + try for_loop [] (Array.to_list projections.product_space) + with e -> + [%log "projections=", (projections : projections)]; + raise e + in + if initialize_neutral && not (is_total ~initialize_neutral ~projections) then + let dims = lazy projections.lhs_dims in + let fetch_op = Constant (Ops.neutral_elem accum) in + Low_level.Seq (loop (Fetch { array = lhs; fetch_op; dims }), for_loops) + else for_loops + and loop code = match code with + | Accum_ternop { initialize_neutral; accum; op; lhs; rhs1; rhs2; rhs3; projections } -> + loop_accum ~initialize_neutral ~accum ~op:(Ops.Ternop op) ~lhs ~rhses:[| rhs1; rhs2; rhs3 |] + projections | Accum_binop { initialize_neutral; accum; op; lhs; rhs1; rhs2; projections } -> - let projections = Lazy.force projections in - let lhs_idx = - derive_index ~product_syms:projections.product_iterators - ~projection:projections.project_lhs - in - let rhs1_idx = - derive_index ~product_syms:projections.product_iterators - ~projection:projections.project_rhs.(0) - in - let rhs2_idx = - derive_index ~product_syms:projections.product_iterators - ~projection:projections.project_rhs.(1) - in - let basecase rev_iters = - let product = Array.of_list_rev_map rev_iters ~f:(fun s -> Indexing.Iterator s) in - let rhs1_idcs = rhs1_idx ~product in - let rhs2_idcs = rhs2_idx ~product in - let lhs_idcs = lhs_idx ~product in - let open Low_level in - let lhs_ll = get (Node lhs) lhs_idcs in - let rhs1_ll = get rhs1 rhs1_idcs in - let rhs2_ll = get rhs2 rhs2_idcs in - let rhs2 = binop ~op ~rhs1:rhs1_ll 
~rhs2:rhs2_ll in - if is_total ~initialize_neutral ~projections then set lhs lhs_idcs rhs2 - else set lhs lhs_idcs @@ binop ~op:accum ~rhs1:lhs_ll ~rhs2 - in - let rec for_loop rev_iters = function - | [] -> basecase rev_iters - | d :: product -> - let index = Indexing.get_symbol () in - For_loop - { - index; - from_ = 0; - to_ = d - 1; - body = for_loop (index :: rev_iters) product; - trace_it = true; - } - in - let for_loops = - try for_loop [] (Array.to_list projections.product_space) - with e -> - [%log "projections=", (projections : projections)]; - raise e - in - if initialize_neutral && not (is_total ~initialize_neutral ~projections) then - let dims = lazy projections.lhs_dims in - let fetch_op = Constant (Ops.neutral_elem accum) in - Low_level.Seq (loop (Fetch { array = lhs; fetch_op; dims }), for_loops) - else for_loops + loop_accum ~initialize_neutral ~accum ~op:(Ops.Binop op) ~lhs ~rhses:[| rhs1; rhs2 |] + projections | Accum_unop { initialize_neutral; accum; op; lhs; rhs; projections } -> - let projections = Lazy.force projections in - let lhs_idx = - derive_index ~product_syms:projections.product_iterators - ~projection:projections.project_lhs - in - let rhs_idx = - derive_index ~product_syms:projections.product_iterators - ~projection:projections.project_rhs.(0) - in - let basecase rev_iters = - let product = Array.of_list_rev_map rev_iters ~f:(fun s -> Indexing.Iterator s) in - let lhs_idcs = lhs_idx ~product in - let open Low_level in - let lhs_ll = get (Node lhs) lhs_idcs in - let rhs_ll = get rhs @@ rhs_idx ~product in - let rhs2 = unop ~op ~rhs:rhs_ll in - if is_total ~initialize_neutral ~projections then set lhs lhs_idcs rhs2 - else set lhs lhs_idcs @@ binop ~op:accum ~rhs1:lhs_ll ~rhs2 - in - let rec for_loop rev_iters = function - | [] -> basecase rev_iters - | d :: product -> - let index = Indexing.get_symbol () in - For_loop - { - index; - from_ = 0; - to_ = d - 1; - body = for_loop (index :: rev_iters) product; - trace_it = true; - } - in - 
let for_loops = for_loop [] (Array.to_list projections.product_space) in - if initialize_neutral && not (is_total ~initialize_neutral ~projections) then - let dims = lazy projections.lhs_dims in - let fetch_op = Constant (Ops.neutral_elem accum) in - Low_level.Seq (loop (Fetch { array = lhs; fetch_op; dims }), for_loops) - else for_loops + loop_accum ~initialize_neutral ~accum ~op:(Ops.Unop op) ~lhs ~rhses:[| rhs |] projections | Noop -> Low_level.Noop | Block_comment (s, c) -> Low_level.unflat_lines [ Comment s; loop c; Comment "end" ] | Seq (c1, c2) -> @@ -251,7 +225,6 @@ let%diagn2_sexp to_low_level code = Low_level.loop_over_dims (Lazy.force dims) ~body:(fun idcs -> set array idcs @@ Get_global (global, Some idcs)) in - loop code let flatten c = @@ -259,7 +232,7 @@ let flatten c = | Noop -> [] | Seq (c1, c2) -> loop c1 @ loop c2 | Block_comment (s, c) -> Block_comment (s, Noop) :: loop c - | (Accum_binop _ | Accum_unop _ | Fetch _) as c -> [ c ] + | (Accum_ternop _ | Accum_binop _ | Accum_unop _ | Fetch _) as c -> [ c ] in loop c @@ -286,6 +259,9 @@ let get_ident_within_code ?no_dots c = loop c1; loop c2 | Block_comment (_, c) -> loop c + | Accum_ternop + { initialize_neutral = _; accum = _; op = _; lhs; rhs1; rhs2; rhs3; projections = _ } -> + List.iter ~f:visit [ lhs; tn rhs1; tn rhs2; tn rhs3 ] | Accum_binop { initialize_neutral = _; accum = _; op = _; lhs; rhs1; rhs2; projections = _ } -> List.iter ~f:visit [ lhs; tn rhs1; tn rhs2 ] | Accum_unop { initialize_neutral = _; accum = _; op = _; lhs; rhs; projections = _ } -> @@ -331,6 +307,16 @@ let fprint_hum ?name ?static_indices () ppf c = | Block_comment (s, c) -> fprintf ppf "# \"%s\";@ " s; loop c + | Accum_ternop { initialize_neutral; accum; op; lhs; rhs1; rhs2; rhs3; projections } -> + let proj_spec = + if Lazy.is_val projections then (Lazy.force projections).debug_info.spec + else "" + in + (* Uncurried syntax for ternary operations. 
*) + fprintf ppf "%s %s %s(%s, %s, %s)%s;@ " (ident lhs) + (Ops.assign_op_cd_syntax ~initialize_neutral accum) + (Ops.ternop_cd_syntax op) (buffer_ident rhs1) (buffer_ident rhs2) (buffer_ident rhs3) + (if not (String.equal proj_spec ".") then " ~logic:\"" ^ proj_spec ^ "\"" else "") | Accum_binop { initialize_neutral; accum; op; lhs; rhs1; rhs2; projections } -> let proj_spec = if Lazy.is_val projections then (Lazy.force projections).debug_info.spec diff --git a/arrayjit/lib/c_syntax.ml b/arrayjit/lib/c_syntax.ml index 19b43dfd..e4607412 100644 --- a/arrayjit/lib/c_syntax.ml +++ b/arrayjit/lib/c_syntax.ml @@ -23,6 +23,7 @@ module C_syntax (B : sig val kernel_prep_line : string val includes : string list val typ_of_prec : Ops.prec -> string + val ternop_syntax : Ops.prec -> Ops.ternop -> string * string * string * string val binop_syntax : Ops.prec -> Ops.binop -> string * string * string val unop_syntax : Ops.prec -> Ops.unop -> string * string val convert_precision : from:Ops.prec -> to_:Ops.prec -> string * string @@ -166,6 +167,7 @@ struct | Get_local _ | Get_global _ | Get _ | Constant _ | Embed_index _ -> 0 | Binop (Arg1, v1, _v2) -> pp_top_locals ppf v1 | Binop (Arg2, _v1, v2) -> pp_top_locals ppf v2 + | Ternop (_, v1, v2, v3) -> pp_top_locals ppf v1 + pp_top_locals ppf v2 + pp_top_locals ppf v3 | Binop (_, v1, v2) -> pp_top_locals ppf v1 + pp_top_locals ppf v2 | Unop (_, v) -> pp_top_locals ppf v and pp_float (prec : Ops.prec) ppf value = @@ -203,6 +205,10 @@ struct fprintf ppf "%s%a%s" prefix pp_axis_index idx postfix | Binop (Arg1, v1, _v2) -> loop ppf v1 | Binop (Arg2, _v1, v2) -> loop ppf v2 + | Ternop (op, v1, v2, v3) -> + let prefix, comma1, comma2, postfix = B.ternop_syntax prec op in + fprintf ppf "@[<1>%s%a%s@ %a%s@ %a@]%s" prefix loop v1 comma1 loop v2 comma2 loop v3 + postfix | Binop (op, v1, v2) -> let prefix, infix, postfix = B.binop_syntax prec op in fprintf ppf "@[<1>%s%a%s@ %a@]%s" prefix loop v1 infix loop v2 postfix @@ -250,6 +256,13 @@ 
struct | Embed_index (Iterator s) -> (Indexing.symbol_ident s, []) | Binop (Arg1, v1, _v2) -> loop v1 | Binop (Arg2, _v1, v2) -> loop v2 + | Ternop (op, v1, v2, v3) -> + let prefix, comma1, comma2, postfix = B.ternop_syntax prec op in + let v1, idcs1 = loop v1 in + let v2, idcs2 = loop v2 in + let v3, idcs3 = loop v3 in + ( String.concat [ prefix; v1; comma1; " "; v2; comma2; " "; v3; postfix ], + idcs1 @ idcs2 @ idcs3 ) | Binop (op, v1, v2) -> let prefix, infix, postfix = B.binop_syntax prec op in let v1, idcs1 = loop v1 in diff --git a/arrayjit/lib/cc_backend.ml b/arrayjit/lib/cc_backend.ml index 67748802..02f6be17 100644 --- a/arrayjit/lib/cc_backend.ml +++ b/arrayjit/lib/cc_backend.ml @@ -84,6 +84,7 @@ struct let kernel_prep_line = "" let includes = [ ""; ""; ""; "" ] let typ_of_prec = Ops.c_typ_of_prec + let ternop_syntax = Ops.ternop_c_syntax let binop_syntax = Ops.binop_c_syntax let unop_syntax = Ops.unop_c_syntax let convert_precision = Ops.c_convert_precision diff --git a/arrayjit/lib/gcc_backend.gccjit.ml b/arrayjit/lib/gcc_backend.gccjit.ml index b3a00031..d415bcbd 100644 --- a/arrayjit/lib/gcc_backend.gccjit.ml +++ b/arrayjit/lib/gcc_backend.gccjit.ml @@ -209,6 +209,18 @@ let debug_log_index ctx log_functions = Block.eval block @@ RValue.call ctx ff [ lf ] | _ -> fun _block _i _index -> () +let assign_op_c_syntax = function + | Ops.Arg1 -> invalid_arg "Gcc_backend.assign_op_c_syntax: Arg1 is not a C assignment operator" + | Arg2 -> "=" + | Add -> "+=" + | Sub -> "-=" + | Mul -> "*=" + | Div -> "/=" + | Mod -> "%=" + (* | Shl -> "<<=" *) + (* | Shr -> ">>=" *) + | _ -> invalid_arg "Gcc_backend.assign_op_c_syntax: not a C assignment operator" + let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node; _ } func initial_block (body : Low_level.t) = let open Gccjit in @@ -335,7 +347,7 @@ let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node; @@ lf :: RValue.string_literal ctx [%string - {|%{node_debug_name 
get_ident node}[%d]{=%g} %{Ops.assign_op_c_syntax accum_op} %g = %{v_format} + {|%{node_debug_name get_ident node}[%d]{=%g} %{assign_op_c_syntax accum_op} %g = %{v_format} |}] :: (to_d @@ RValue.lvalue @@ LValue.access_array (Lazy.force node.ptr) offset) :: offset :: to_d value :: v_fillers; diff --git a/arrayjit/lib/low_level.ml b/arrayjit/lib/low_level.ml index ae854629..1f61e98e 100644 --- a/arrayjit/lib/low_level.ml +++ b/arrayjit/lib/low_level.ml @@ -45,16 +45,22 @@ and float_t = | Get_local of scope_id | Get_global of Ops.global_identifier * Indexing.axis_index array option | Get of Tn.t * Indexing.axis_index array + | Ternop of Ops.ternop * float_t * float_t * float_t | Binop of Ops.binop * float_t * float_t | Unop of Ops.unop * float_t | Constant of float | Embed_index of Indexing.axis_index [@@deriving sexp_of, equal, compare] -let binop ~op ~rhs1 ~rhs2 = - match op with Ops.Arg1 -> rhs1 | Arg2 -> rhs2 | _ -> Binop (op, rhs1, rhs2) - -let unop ~op ~rhs = match op with Ops.Identity -> rhs | _ -> Unop (op, rhs) +let apply_op op args = (* Builds the [float_t] for [op] applied to [args]; Arg1, Arg2 and Identity return the selected argument itself. Raises [Invalid_argument] when the length of [args] does not match the arity of [op]. *) + match (op, args) with + | Ops.Binop Ops.Arg1, [| rhs1; _ |] -> rhs1 + | Binop Arg2, [| _; rhs2 |] -> rhs2 + | Unop Identity, [| rhs |] -> rhs + | Ternop op, [| rhs1; rhs2; rhs3 |] -> Ternop (op, rhs1, rhs2, rhs3) + | Binop op, [| rhs1; rhs2 |] -> Binop (op, rhs1, rhs2) + | Unop op, [| rhs |] -> Unop (op, rhs) + | _ -> invalid_arg "Low_level.apply_op: invalid number of arguments" let rec flat_lines ts = List.concat_map ts ~f:(function Seq (t1, t2) -> flat_lines [ t1; t2 ] | t -> [ t ]) @@ -134,6 +140,7 @@ let is_constexpr_comp traced_store llv = | Get (tn, _idcs) -> let traced = get_node traced_store tn in traced.is_scalar_constexpr + | Ternop (_, v1, v2, v3) -> loop v1 && loop v2 && loop v3 | Binop (_, v1, v2) -> loop v1 && loop v2 | Unop (_, v) -> loop v | Constant _ -> true @@ -212,6 +219,10 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc = | Embed_index _ -> () | Binop (Arg1, llv1, _llv2) -> loop 
llv1 | Binop (Arg2, _llv1, llv2) -> loop llv2 + | Ternop (_, llv1, llv2, llv3) -> + loop llv1; + loop llv2; + loop llv3 | Binop (_, llv1, llv2) -> loop llv1; loop llv2 @@ -325,6 +336,10 @@ let%diagn2_sexp check_and_store_virtual traced static_indices top_llc = [%log2 "Inlining candidate has an escaping variable", (s : Indexing.symbol), (top_llc : t)]; raise @@ Non_virtual 10) + | Ternop (_, llv1, llv2, llv3) -> + loop_float ~env_dom llv1; + loop_float ~env_dom llv2; + loop_float ~env_dom llv3 | Binop (_, llv1, llv2) -> loop_float ~env_dom llv1; loop_float ~env_dom llv2 @@ -404,6 +419,8 @@ let inline_computation ~id traced static_indices call_args = | Get_local _ -> llv | Get_global _ -> llv | Embed_index idx -> Embed_index (subst env idx) + | Ternop (op, llv1, llv2, llv3) -> + Ternop (op, loop_float env llv1, loop_float env llv2, loop_float env llv3) | Binop (op, llv1, llv2) -> Binop (op, loop_float env llv1, loop_float env llv2) | Unop (op, llv) -> Unop (op, loop_float env llv) in @@ -478,6 +495,12 @@ let virtual_llc traced_store reverse_node_map static_indices (llc : t) : t = | Get_local _ -> llv | Get_global _ -> llv | Embed_index _ -> llv + | Ternop (op, llv1, llv2, llv3) -> + Ternop + ( op, + loop_float ~process_for llv1, + loop_float ~process_for llv2, + loop_float ~process_for llv3 ) | Binop (op, llv1, llv2) -> Binop (op, loop_float ~process_for llv1, loop_float ~process_for llv2) | Unop (op, llv) -> Unop (op, loop_float ~process_for llv) @@ -557,6 +580,7 @@ let cleanup_virtual_llc reverse_node_map ~static_indices (llc : t) : t = | Embed_index (Iterator s) -> assert (Set.mem env_dom s); llv + | Ternop (op, llv1, llv2, llv3) -> Ternop (op, loop llv1, loop llv2, loop llv3) | Binop (op, llv1, llv2) -> Binop (op, loop llv1, loop llv2) | Unop (op, llv) -> Unop (op, loop llv) in @@ -578,6 +602,7 @@ let rec substitute_float ~var ~value llv = | Get_local _ -> llv | Get_global _ -> llv | Embed_index _ -> llv + | Ternop (op, llv1, llv2, llv3) -> Ternop (op, loop_float 
llv1, loop_float llv2, loop_float llv3) | Binop (op, llv1, llv2) -> Binop (op, loop_float llv1, loop_float llv2) | Unop (op, llv) -> Unop (op, loop_float llv) @@ -640,6 +665,13 @@ let simplify_llc llc = | Get_global _ -> llv | Embed_index (Fixed_idx i) -> Constant (Float.of_int i) | Embed_index (Iterator _) -> llv + | Ternop (op, llv1, llv2, llv3) -> + (* FIXME: NOT IMPLEMENTED YET: no ternop-specific rewrites; the three arguments are simplified and the rebuilt node is re-simplified only if any of them changed. *) + let v1 = loop_float llv1 in + let v2 = loop_float llv2 in + let v3 = loop_float llv3 in + let result = Ternop (op, v1, v2, v3) in + if equal_float_t llv1 v1 && equal_float_t llv2 v2 && equal_float_t llv3 v3 then result else loop_float result | Binop (Arg1, llv1, _) -> loop_float llv1 | Binop (Arg2, _, llv2) -> loop_float llv2 | Binop (op, Constant c1, Constant c2) -> Constant (Ops.interpret_binop op c1 c2) @@ -725,6 +757,10 @@ let simplify_llc llc = match llv with | Constant c -> check_constant tn c | Local_scope { body; _ } -> check_proc body + | Ternop (_, v1, v2, v3) -> + loop v1; + loop v2; + loop v3 | Binop (_, v1, v2) -> loop v1; loop v2 @@ -828,6 +864,10 @@ let get_ident_within_code ?no_dots llcs = loop body | Get_global (_, _) -> () | Get (la, _) -> visit la + | Ternop (_, f1, f2, f3) -> + loop_float f1; + loop_float f2; + loop_float f3 | Binop (_, f1, f2) -> loop_float f1; loop_float f2 @@ -893,6 +933,10 @@ let fprint_cstyle ?name ?static_indices () ppf llc = | Get (tn, idcs) -> fprintf ppf "@[<2>%a[@,%a]@]" pp_ident tn pp_indices idcs | Constant c -> fprintf ppf "%.16g" c | Embed_index idx -> pp_axis_index ppf idx + | Ternop (op, v1, v2, v3) -> + let prefix, comma1, comma2, postfix = Ops.ternop_c_syntax prec op in + fprintf ppf "@[<1>%s%a%s@ %a%s@ %a@]%s" prefix (pp_float prec) v1 comma1 (pp_float prec) v2 + comma2 (pp_float prec) v3 postfix | Binop (Arg1, v1, _v2) -> pp_float prec ppf v1 | Binop (Arg2, _v1, v2) -> pp_float prec ppf v2 | Binop (op, v1, v2) -> @@ -948,6 +992,9 @@ let fprint_hum ?name ?static_indices () ppf llc = | Get (tn, idcs) -> fprintf ppf "@[<2>%a[@,%a]@]" pp_ident tn 
pp_indices idcs | Constant c -> fprintf ppf "%.16g" c | Embed_index idx -> pp_axis_index ppf idx + | Ternop (op, v1, v2, v3) -> + let prefix = Ops.ternop_cd_syntax op in + fprintf ppf "@[<1>%s(%a,@ %a,@ %a@])" prefix pp_float v1 pp_float v2 pp_float v3 | Binop (Arg1, v1, _v2) -> pp_float ppf v1 | Binop (Arg2, _v1, v2) -> pp_float ppf v2 | Binop (op, v1, v2) -> diff --git a/arrayjit/lib/low_level.mli b/arrayjit/lib/low_level.mli index 1b2cc4b4..b5936ea9 100644 --- a/arrayjit/lib/low_level.mli +++ b/arrayjit/lib/low_level.mli @@ -31,14 +31,14 @@ and float_t = | Get_local of scope_id | Get_global of Ops.global_identifier * Indexing.axis_index array option | Get of Tnode.t * Indexing.axis_index array + | Ternop of Ops.ternop * float_t * float_t * float_t | Binop of Ops.binop * float_t * float_t | Unop of Ops.unop * float_t | Constant of float | Embed_index of Indexing.axis_index [@@deriving sexp_of, equal, compare] -val binop : op:Ops.binop -> rhs1:float_t -> rhs2:float_t -> float_t -val unop : op:Ops.unop -> rhs:float_t -> float_t +val apply_op : Ops.op -> float_t array -> float_t val flat_lines : t list -> t list val unflat_lines : t list -> t val loop_over_dims : int array -> body:(Indexing.axis_index array -> t) -> t diff --git a/arrayjit/lib/ops.ml b/arrayjit/lib/ops.ml index 6d4418fa..243ae1eb 100644 --- a/arrayjit/lib/ops.ml +++ b/arrayjit/lib/ops.ml @@ -175,6 +175,8 @@ type unop = type ternop = Where (** Where(a,b,c): if a then b else c *) | FMA (** FMA(a,b,c): (a * b) + c *) [@@deriving sexp, compare, equal] +type op = Ternop of ternop | Binop of binop | Unop of unop [@@deriving sexp, compare, equal] + (** Either the left-neutral or right-neutral element of the operation. Unspecified if the operation does not have a neutral element. 
*) let neutral_elem = function @@ -230,11 +232,12 @@ let interpret_unop op v = | Neg -> ~-.v | Tanh_approx -> tanh v -let is_binop_infix _ = true +let interpret_ternop op v1 v2 v3 = + let open Float in + match op with Where -> if v1 <> 0. then v2 else v3 | FMA -> (v1 * v2) + v3 -let is_binop_nice_infix = function - | Arg1 | Arg2 | Relu_gate | Max | Min -> false - | _ -> true +let is_binop_infix _ = true +let is_binop_nice_infix = function Arg1 | Arg2 | Relu_gate | Max | Min -> false | _ -> true let binop_cd_syntax = function | Arg1 -> "-@>" @@ -252,8 +255,8 @@ let binop_cd_syntax = function | Mod -> "%" | Max -> "@^" | Min -> "^^" - (* | Shl -> "lsl" *) - (* | Shr -> "lsr" *) +(* | Shl -> "lsl" *) +(* | Shr -> "lsr" *) let binop_cd_fallback_syntax = function | Arg1 -> "fst" @@ -271,8 +274,8 @@ let binop_cd_fallback_syntax = function | Mod -> "modf" | Max -> "max" | Min -> "min" - (* | Shl -> "shlf" *) - (* | Shr -> "shrf" *) +(* | Shl -> "shlf" *) +(* | Shr -> "shrf" *) let binop_c_syntax prec v = match (v, prec) with @@ -332,18 +335,6 @@ let assign_op_cd_syntax ~initialize_neutral = function | Arg1 | Mod (* | Shl | Shr *) | Cmplt | Cmpne -> invalid_arg "Ops.assign_op_cd_syntax: not an assignment op" -let assign_op_c_syntax = function - | Arg1 -> invalid_arg "Ops.assign_op_c_syntax: Arg1 is not a C assignment operator" - | Arg2 -> "=" - | Add -> "+=" - | Sub -> "-=" - | Mul -> "*=" - | Div -> "/=" - | Mod -> "%=" - (* | Shl -> "<<=" *) - (* | Shr -> ">>=" *) - | _ -> invalid_arg "Ops.assign_op_c_syntax: not a C assignment operator" - (** Note: currently we do not support unary prefix symbols. 
*) let unop_cd_syntax = function | Identity -> "id" @@ -361,7 +352,7 @@ let unop_cd_syntax = function | Neg -> "neg" | Tanh_approx -> "tanh" -let unop_c_syntax prec v = +let unop_c_syntax prec op = let fmax () = (* See: https://en.cppreference.com/w/c/numeric/math/fmax option (4) *) match prec with @@ -374,7 +365,7 @@ let unop_c_syntax prec v = | Double_prec _ | Byte_prec _ -> "fmax" | _ -> "fmaxf" in - match (v, prec) with + match (op, prec) with | Identity, _ -> ("", "") | Relu, Byte_prec _ -> ("fmax(0, ", ")") | Relu, _ -> (fmax () ^ "(0.0, ", ")") @@ -406,6 +397,15 @@ let unop_c_syntax prec v = invalid_arg "Ops.unop_c_syntax: Tanh_approx not supported for byte/integer precisions" | Tanh_approx, _ -> ("tanhf(", ")") +let ternop_cd_syntax = function Where -> "where" | FMA -> "fma" + +let ternop_c_syntax prec op = + match (op, prec) with + | Where, Byte_prec _ -> ("((", ") != 0 ? (", ") : (", "))") + | Where, _ -> ("((", ") != 0.0 ? (", ") : (", "))") + | FMA, (Double_prec _ | Byte_prec _) -> ("fma(", ",", ",", ")") + | FMA, _ -> ("fmaf(", ",", ",", ")") + let c_convert_precision ~from ~to_ = match (from, to_) with | Double_prec _, Double_prec _