Skip to content

Commit

Permalink
Ternary primitive operations, in progress
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Jan 25, 2025
1 parent 2bd4336 commit fde7983
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 120 deletions.
168 changes: 77 additions & 91 deletions arrayjit/lib/assignments.ml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ and t =
| Noop
| Seq of t * t
| Block_comment of string * t (** Same as the given code, with a comment. *)
| Accum_ternop of {
initialize_neutral : bool;
accum : Ops.binop;
op : Ops.ternop;
lhs : Tn.t;
rhs1 : buffer;
rhs2 : buffer;
rhs3 : buffer;
projections : Indexing.projections Lazy.t;
}
| Accum_binop of {
initialize_neutral : bool;
accum : Ops.binop;
Expand Down Expand Up @@ -93,6 +103,8 @@ let%debug3_sexp context_nodes ~(use_host_memory : 'a option) (asgns : t) : Tn.t_
| Accum_unop { lhs; rhs; _ } -> Set.union (one lhs) (of_node rhs)
| Accum_binop { lhs; rhs1; rhs2; _ } ->
Set.union_list (module Tn) [ one lhs; of_node rhs1; of_node rhs2 ]
| Accum_ternop { lhs; rhs1; rhs2; rhs3; _ } ->
Set.union_list (module Tn) [ one lhs; of_node rhs1; of_node rhs2; of_node rhs3 ]
| Fetch { array; _ } -> one array
in
loop asgns
Expand Down Expand Up @@ -139,98 +151,60 @@ let%diagn2_sexp to_low_level code =
assert (Array.length idcs = Array.length (Lazy.force tn.Tn.dims));
Low_level.Set { tn; idcs; llv; debug = "" }
in
let rec loop code =
let rec loop_accum ~initialize_neutral ~accum ~op ~lhs ~rhses projections =
let projections = Lazy.force projections in
let lhs_idx =
derive_index ~product_syms:projections.product_iterators ~projection:projections.project_lhs
in
let rhs_idcs =
Array.map projections.project_rhs ~f:(fun projection ->
derive_index ~product_syms:projections.product_iterators ~projection)
in
let basecase rev_iters =
let product = Array.of_list_rev_map rev_iters ~f:(fun s -> Indexing.Iterator s) in
let rhses_idcs = Array.map rhs_idcs ~f:(fun rhs_idx -> rhs_idx ~product) in
let lhs_idcs = lhs_idx ~product in
let open Low_level in
let lhs_ll = get (Node lhs) lhs_idcs in
let rhses_ll = Array.mapi rhses_idcs ~f:(fun i rhs_idcs -> get rhses.(i) rhs_idcs) in
let rhs2 = apply_op op rhses_ll in
if is_total ~initialize_neutral ~projections then set lhs lhs_idcs rhs2
else set lhs lhs_idcs @@ apply_op (Ops.Binop accum) [| lhs_ll; rhs2 |]
in
let rec for_loop rev_iters = function
| [] -> basecase rev_iters
| d :: product ->
let index = Indexing.get_symbol () in
For_loop
{
index;
from_ = 0;
to_ = d - 1;
body = for_loop (index :: rev_iters) product;
trace_it = true;
}
in
let for_loops =
try for_loop [] (Array.to_list projections.product_space)
with e ->
[%log "projections=", (projections : projections)];
raise e
in
if initialize_neutral && not (is_total ~initialize_neutral ~projections) then
let dims = lazy projections.lhs_dims in
let fetch_op = Constant (Ops.neutral_elem accum) in
Low_level.Seq (loop (Fetch { array = lhs; fetch_op; dims }), for_loops)
else for_loops
and loop code =
match code with
| Accum_ternop { initialize_neutral; accum; op; lhs; rhs1; rhs2; rhs3; projections } ->
loop_accum ~initialize_neutral ~accum ~op:(Ops.Ternop op) ~lhs ~rhses:[| rhs1; rhs2; rhs3 |]
projections
| Accum_binop { initialize_neutral; accum; op; lhs; rhs1; rhs2; projections } ->
let projections = Lazy.force projections in
let lhs_idx =
derive_index ~product_syms:projections.product_iterators
~projection:projections.project_lhs
in
let rhs1_idx =
derive_index ~product_syms:projections.product_iterators
~projection:projections.project_rhs.(0)
in
let rhs2_idx =
derive_index ~product_syms:projections.product_iterators
~projection:projections.project_rhs.(1)
in
let basecase rev_iters =
let product = Array.of_list_rev_map rev_iters ~f:(fun s -> Indexing.Iterator s) in
let rhs1_idcs = rhs1_idx ~product in
let rhs2_idcs = rhs2_idx ~product in
let lhs_idcs = lhs_idx ~product in
let open Low_level in
let lhs_ll = get (Node lhs) lhs_idcs in
let rhs1_ll = get rhs1 rhs1_idcs in
let rhs2_ll = get rhs2 rhs2_idcs in
let rhs2 = binop ~op ~rhs1:rhs1_ll ~rhs2:rhs2_ll in
if is_total ~initialize_neutral ~projections then set lhs lhs_idcs rhs2
else set lhs lhs_idcs @@ binop ~op:accum ~rhs1:lhs_ll ~rhs2
in
let rec for_loop rev_iters = function
| [] -> basecase rev_iters
| d :: product ->
let index = Indexing.get_symbol () in
For_loop
{
index;
from_ = 0;
to_ = d - 1;
body = for_loop (index :: rev_iters) product;
trace_it = true;
}
in
let for_loops =
try for_loop [] (Array.to_list projections.product_space)
with e ->
[%log "projections=", (projections : projections)];
raise e
in
if initialize_neutral && not (is_total ~initialize_neutral ~projections) then
let dims = lazy projections.lhs_dims in
let fetch_op = Constant (Ops.neutral_elem accum) in
Low_level.Seq (loop (Fetch { array = lhs; fetch_op; dims }), for_loops)
else for_loops
loop_accum ~initialize_neutral ~accum ~op:(Ops.Binop op) ~lhs ~rhses:[| rhs1; rhs2 |]
projections
| Accum_unop { initialize_neutral; accum; op; lhs; rhs; projections } ->
let projections = Lazy.force projections in
let lhs_idx =
derive_index ~product_syms:projections.product_iterators
~projection:projections.project_lhs
in
let rhs_idx =
derive_index ~product_syms:projections.product_iterators
~projection:projections.project_rhs.(0)
in
let basecase rev_iters =
let product = Array.of_list_rev_map rev_iters ~f:(fun s -> Indexing.Iterator s) in
let lhs_idcs = lhs_idx ~product in
let open Low_level in
let lhs_ll = get (Node lhs) lhs_idcs in
let rhs_ll = get rhs @@ rhs_idx ~product in
let rhs2 = unop ~op ~rhs:rhs_ll in
if is_total ~initialize_neutral ~projections then set lhs lhs_idcs rhs2
else set lhs lhs_idcs @@ binop ~op:accum ~rhs1:lhs_ll ~rhs2
in
let rec for_loop rev_iters = function
| [] -> basecase rev_iters
| d :: product ->
let index = Indexing.get_symbol () in
For_loop
{
index;
from_ = 0;
to_ = d - 1;
body = for_loop (index :: rev_iters) product;
trace_it = true;
}
in
let for_loops = for_loop [] (Array.to_list projections.product_space) in
if initialize_neutral && not (is_total ~initialize_neutral ~projections) then
let dims = lazy projections.lhs_dims in
let fetch_op = Constant (Ops.neutral_elem accum) in
Low_level.Seq (loop (Fetch { array = lhs; fetch_op; dims }), for_loops)
else for_loops
loop_accum ~initialize_neutral ~accum ~op:(Ops.Unop op) ~lhs ~rhses:[| rhs |] projections
| Noop -> Low_level.Noop
| Block_comment (s, c) -> Low_level.unflat_lines [ Comment s; loop c; Comment "end" ]
| Seq (c1, c2) ->
Expand All @@ -251,15 +225,14 @@ let%diagn2_sexp to_low_level code =
Low_level.loop_over_dims (Lazy.force dims) ~body:(fun idcs ->
set array idcs @@ Get_global (global, Some idcs))
in

loop code

(** [flatten c] linearizes the assignment tree [c] into a list of statements, preserving
    left-to-right order. [Noop] contributes nothing, [Seq] concatenates the flattenings of its
    two components, and [Block_comment (s, body)] yields the marker [Block_comment (s, Noop)]
    followed by the flattening of [body]. Accumulation and fetch statements are returned
    unchanged as singleton lists. *)
let flatten c =
  let rec loop = function
    | Noop -> []
    | Seq (c1, c2) -> loop c1 @ loop c2
    | Block_comment (s, c) -> Block_comment (s, Noop) :: loop c
    (* One arm handles every leaf statement. Listing the binary/unary/fetch cases in a
       separate earlier arm would make their repetition here a redundant sub-pattern. *)
    | (Accum_ternop _ | Accum_binop _ | Accum_unop _ | Fetch _) as c -> [ c ]
  in
  loop c

Expand All @@ -286,6 +259,9 @@ let get_ident_within_code ?no_dots c =
loop c1;
loop c2
| Block_comment (_, c) -> loop c
| Accum_ternop
{ initialize_neutral = _; accum = _; op = _; lhs; rhs1; rhs2; rhs3; projections = _ } ->
List.iter ~f:visit [ lhs; tn rhs1; tn rhs2; tn rhs3 ]
| Accum_binop { initialize_neutral = _; accum = _; op = _; lhs; rhs1; rhs2; projections = _ } ->
List.iter ~f:visit [ lhs; tn rhs1; tn rhs2 ]
| Accum_unop { initialize_neutral = _; accum = _; op = _; lhs; rhs; projections = _ } ->
Expand Down Expand Up @@ -331,6 +307,16 @@ let fprint_hum ?name ?static_indices () ppf c =
| Block_comment (s, c) ->
fprintf ppf "# \"%s\";@ " s;
loop c
| Accum_ternop { initialize_neutral; accum; op; lhs; rhs1; rhs2; rhs3; projections } ->
let proj_spec =
if Lazy.is_val projections then (Lazy.force projections).debug_info.spec
else "<not-in-yet>"
in
(* Uncurried syntax for ternary operations. *)
fprintf ppf "%s %s %s(%s, %s, %s)%s;@ " (ident lhs)
(Ops.assign_op_cd_syntax ~initialize_neutral accum)
(Ops.ternop_cd_syntax op) (buffer_ident rhs1) (buffer_ident rhs2) (buffer_ident rhs3)
(if not (String.equal proj_spec ".") then " ~logic:\"" ^ proj_spec ^ "\"" else "")
| Accum_binop { initialize_neutral; accum; op; lhs; rhs1; rhs2; projections } ->
let proj_spec =
if Lazy.is_val projections then (Lazy.force projections).debug_info.spec
Expand Down
13 changes: 13 additions & 0 deletions arrayjit/lib/c_syntax.ml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ module C_syntax (B : sig
val kernel_prep_line : string
val includes : string list
val typ_of_prec : Ops.prec -> string
val ternop_syntax : Ops.prec -> Ops.ternop -> string * string * string * string
val binop_syntax : Ops.prec -> Ops.binop -> string * string * string
val unop_syntax : Ops.prec -> Ops.unop -> string * string
val convert_precision : from:Ops.prec -> to_:Ops.prec -> string * string
Expand Down Expand Up @@ -166,6 +167,7 @@ struct
| Get_local _ | Get_global _ | Get _ | Constant _ | Embed_index _ -> 0
| Binop (Arg1, v1, _v2) -> pp_top_locals ppf v1
| Binop (Arg2, _v1, v2) -> pp_top_locals ppf v2
| Ternop (_, v1, v2, v3) -> pp_top_locals ppf v1 + pp_top_locals ppf v2 + pp_top_locals ppf v3
| Binop (_, v1, v2) -> pp_top_locals ppf v1 + pp_top_locals ppf v2
| Unop (_, v) -> pp_top_locals ppf v
and pp_float (prec : Ops.prec) ppf value =
Expand Down Expand Up @@ -203,6 +205,10 @@ struct
fprintf ppf "%s%a%s" prefix pp_axis_index idx postfix
| Binop (Arg1, v1, _v2) -> loop ppf v1
| Binop (Arg2, _v1, v2) -> loop ppf v2
| Ternop (op, v1, v2, v3) ->
let prefix, comma1, comma2, postfix = B.ternop_syntax prec op in
fprintf ppf "@[<1>%s%a%s@ %a%s@ %a@]%s" prefix loop v1 comma1 loop v2 comma2 loop v3
postfix
| Binop (op, v1, v2) ->
let prefix, infix, postfix = B.binop_syntax prec op in
fprintf ppf "@[<1>%s%a%s@ %a@]%s" prefix loop v1 infix loop v2 postfix
Expand Down Expand Up @@ -250,6 +256,13 @@ struct
| Embed_index (Iterator s) -> (Indexing.symbol_ident s, [])
| Binop (Arg1, v1, _v2) -> loop v1
| Binop (Arg2, _v1, v2) -> loop v2
| Ternop (op, v1, v2, v3) ->
let prefix, comma1, comma2, postfix = B.ternop_syntax prec op in
let v1, idcs1 = loop v1 in
let v2, idcs2 = loop v2 in
let v3, idcs3 = loop v3 in
( String.concat [ prefix; v1; comma1; " "; v2; comma2; " "; v3; postfix ],
idcs1 @ idcs2 @ idcs3 )
| Binop (op, v1, v2) ->
let prefix, infix, postfix = B.binop_syntax prec op in
let v1, idcs1 = loop v1 in
Expand Down
1 change: 1 addition & 0 deletions arrayjit/lib/cc_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ struct
let kernel_prep_line = ""
let includes = [ "<stdio.h>"; "<stdlib.h>"; "<string.h>"; "<math.h>" ]
let typ_of_prec = Ops.c_typ_of_prec
let ternop_syntax = Ops.ternop_c_syntax
let binop_syntax = Ops.binop_c_syntax
let unop_syntax = Ops.unop_c_syntax
let convert_precision = Ops.c_convert_precision
Expand Down
14 changes: 13 additions & 1 deletion arrayjit/lib/gcc_backend.gccjit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,18 @@ let debug_log_index ctx log_functions =
Block.eval block @@ RValue.call ctx ff [ lf ]
| _ -> fun _block _i _index -> ()

(** [assign_op_c_syntax op] is the C assignment operator token that corresponds to
    accumulating with [op]: ["="] for [Arg2] and the compound forms ["+="], ["-="], ["*="],
    ["/="], ["%="] for the arithmetic operators.

    @raise Invalid_argument
      if [op] is [Arg1] or any other operator without a C assignment form handled here. *)
let assign_op_c_syntax (op : Ops.binop) : string =
  match op with
  | Ops.Add -> "+="
  | Ops.Sub -> "-="
  | Ops.Mul -> "*="
  | Ops.Div -> "/="
  | Ops.Mod -> "%="
  (* Shift-assignments are not enabled yet: *)
  (* | Shl -> "<<=" *)
  (* | Shr -> ">>=" *)
  | Ops.Arg2 -> "="
  | Ops.Arg1 -> invalid_arg "Gcc_backend.assign_op_c_syntax: Arg1 is not a C assignment operator"
  | _ -> invalid_arg "Gcc_backend.assign_op_c_syntax: not a C assignment operator"

let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node; _ } func
initial_block (body : Low_level.t) =
let open Gccjit in
Expand Down Expand Up @@ -335,7 +347,7 @@ let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node;
@@ lf
:: RValue.string_literal ctx
[%string
{|%{node_debug_name get_ident node}[%d]{=%g} %{Ops.assign_op_c_syntax accum_op} %g = %{v_format}
{|%{node_debug_name get_ident node}[%d]{=%g} %{assign_op_c_syntax accum_op} %g = %{v_format}
|}]
:: (to_d @@ RValue.lvalue @@ LValue.access_array (Lazy.force node.ptr) offset)
:: offset :: to_d value :: v_fillers;
Expand Down
Loading

0 comments on commit fde7983

Please sign in to comment.