Skip to content

Commit 0cc8fd3

Browse files
authored
Merge pull request #1004 from SteveBronder/feature/add-promoter
Add promoter expression
2 parents c58fcc1 + af2d260 commit 0cc8fd3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+2939
-1922
lines changed

src/analysis_and_optimization/Mem_pattern.ml

+43-37
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ let rec matrix_set Expr.Fixed.{pattern; meta= Expr.Typed.Meta.{type_; _} as meta
1717
if UnsizedType.contains_eigen_type type_ then union_recur exprs
1818
else Set.Poly.empty
1919
| TernaryIf (_, expr2, expr3) -> union_recur [expr2; expr3]
20-
| Indexed (expr, _) -> matrix_set expr
20+
| Indexed (expr, _) | Promotion (expr, _, _) -> matrix_set expr
2121
| EAnd (expr1, expr2) | EOr (expr1, expr2) -> union_recur [expr1; expr2]
2222
else Set.Poly.empty
2323

@@ -45,7 +45,7 @@ let is_nonzero_subset ~set ~subset =
4545
&& not (Set.Poly.is_empty subset)
4646

4747
(**
48-
* Check an expression to count how many times we see a single index.
48+
* Check an expression to count how many times we see a single index.
4949
* @param acc An accumulator from previous folds of multiple expressions.
5050
* @param pattern The expression patterns to match against
5151
*)
@@ -62,6 +62,7 @@ let rec count_single_idx_exprs (acc : int) Expr.Fixed.{pattern; _} : int =
6262
acc
6363
+ count_single_idx_exprs 0 idx_expr
6464
+ List.fold_left ~init:0 ~f:count_single_idx indexed
65+
| Promotion (expr, _, _) -> count_single_idx_exprs acc expr
6566
| EAnd (lhs, rhs) ->
6667
acc + count_single_idx_exprs 0 lhs + count_single_idx_exprs 0 rhs
6768
| EOr (lhs, rhs) ->
@@ -72,9 +73,9 @@ let rec count_single_idx_exprs (acc : int) Expr.Fixed.{pattern; _} : int =
7273
(**
7374
* Check an Index to count how many times we see a single index.
7475
* @param acc An accumulator from previous folds of multiple expressions.
75-
* @param idx An Index to match. For Single types this adds 1 to the
76-
* acc. For Upfrom and MultiIndex types we check the inner expression
77-
* for a Single index. All and Between cannot be Single cell access
76+
* @param idx An Index to match. For Single types this adds 1 to the
77+
* acc. For Upfrom and MultiIndex types we check the inner expression
78+
* for a Single index. All and Between cannot be Single cell access
7879
* and so pass acc along.
7980
*)
8081
and count_single_idx (acc : int) (idx : Expr.Typed.Meta.t Expr.Fixed.t Index.t)
@@ -84,13 +85,13 @@ and count_single_idx (acc : int) (idx : Expr.Typed.Meta.t Expr.Fixed.t Index.t)
8485
| Single _ -> acc + 1
8586

8687
(**
87-
* Find indices on Matrix and Vector types that perform single
88+
* Find indices on Matrix and Vector types that perform single
8889
* cell access. Returns true if it finds
8990
* a vector, row vector, matrix, or matrix with single cell access
90-
* as well as an array of any of the above that is accessing the
91+
* as well as an array of any of the above that is accessing the
9192
* inner matrix types cell.
9293
* @param ut An UnsizedType to match against.
93-
* @param index This list is checked for Single cell access
94+
* @param index This list is checked for Single cell access
9495
* either at the top level or within the `Index` types of the list.
9596
*)
9697
let rec is_uni_eigen_loop_indexing in_loop (ut : UnsizedType.t)
@@ -122,7 +123,7 @@ let is_fun_soa_supported name exprs =
122123
* see the docs for `query_initial_demotable_funs`.
123124
* @param in_loop a boolean to signify if the expression exists inside
124125
* of a loop. If so, the names of matrix and vector like objects
125-
* will be returned if the matrix or vector is accessed by single
126+
* will be returned if the matrix or vector is accessed by single
126127
* cell indexing.
127128
*)
128129
let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
@@ -147,6 +148,7 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
147148
Set.Poly.union acc index_demotes
148149
| Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) ->
149150
acc
151+
| Promotion (expr, _, _) -> query_expr acc expr
150152
| TernaryIf (predicate, texpr, fexpr) ->
151153
let predicate_demotes = query_expr acc predicate in
152154
Set.Poly.union
@@ -159,18 +161,18 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
159161
Set.Poly.union (query_expr full_lhs_rhs lhs) (query_expr full_lhs_rhs rhs)
160162

161163
(**
162-
* Query a function to detect if it or any of its used
164+
* Query a function to detect if it or any of its used
163165
* expression's objects or expressions should be demoted to AoS.
164166
*
165167
* The logic here demotes the expressions in a function to AoS if
166-
* the function's inner expression returns has a meta type containing a matrix
168+
* the function's inner expression returns has a meta type containing a matrix
167169
* and either of :
168170
* (1) The function is user defined and the UDFs inputs are matrices.
169171
* (2) The Stan math function cannot support AoS
170172
* @param in_loop A boolean to specify the logic of indexing expressions. See
171173
* `query_initial_demotable_expr` for an explanation of the logic.
172-
* @param kind The function type, for StanLib functions we check if the
173-
* function supports SoA and for UserDefined functions we always fail
174+
* @param kind The function type, for StanLib functions we check if the
175+
* function supports SoA and for UserDefined functions we always fail
174176
* and return back all of the names of the objects passed in expressions
175177
* to the UDF.
176178
* exprs The expression list passed to the functions.
@@ -212,7 +214,8 @@ let rec is_any_soa_supported_expr
212214
match pattern with
213215
| FunApp (kind, (exprs : Expr.Typed.Meta.t Expr.Fixed.t list)) ->
214216
is_any_soa_supported_fun_expr kind exprs
215-
| Indexed (expr, (_ : Typed.Meta.t Fixed.t Index.t list)) ->
217+
| Indexed (expr, (_ : Typed.Meta.t Fixed.t Index.t list))
218+
|Promotion (expr, _, _) ->
216219
is_any_soa_supported_expr expr
217220
| Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) ->
218221
true
@@ -248,7 +251,8 @@ let rec is_any_ad_real_data_matrix_expr
248251
match pattern with
249252
| FunApp (kind, (exprs : Expr.Typed.Meta.t Expr.Fixed.t list)) ->
250253
is_any_ad_real_data_matrix_expr_fun kind exprs
251-
| Indexed (expr, _) -> is_any_ad_real_data_matrix_expr expr
254+
| Indexed (expr, _) | Promotion (expr, _, _) ->
255+
is_any_ad_real_data_matrix_expr expr
252256
| Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) ->
253257
false
254258
| TernaryIf (_, texpr, fexpr) ->
@@ -303,15 +307,15 @@ and is_any_ad_real_data_matrix_expr_fun (kind : 'a Fun_kind.t)
303307

304308
(**
305309
* Query to find the initial set of objects in statements that cannot be SoA.
306-
* This is mostly recursive over expressions and statements, with the exception of
310+
* This is mostly recursive over expressions and statements, with the exception of
307311
* functions and Assignments.
308312
*
309313
* For assignments:
310314
* We demote the LHS variable if any of the following are true:
311-
* 1. None of the RHS's functions are able to accept SoA matrices
315+
* 1. None of the RHS's functions are able to accept SoA matrices
312316
* and the rhs is not an internal compiler function.
313317
* 2. A single cell of the LHS is being assigned within a loop.
314-
* 3. The top level expression on the RHS is a combination of only
318+
* 3. The top level expression on the RHS is a combination of only
315319
* data matrices and scalar types. Operations on data matrix and
316320
* scalar values in Stan math will return a AoS matrix. We currently
317321
* have no way to tell Stan math to return a SoA matrix.
@@ -408,11 +412,11 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
408412
(** Look through a statement to see whether the objects used in it need to be
409413
* modified from SoA to AoS. Returns the set of object names that need demoted
410414
* in a statement, if any.
411-
* This function looks at Assignment statements, and returns back the
415+
* This function looks at Assignment statements, and returns back the
412416
* set of top level object names given:
413-
* 1. If the name of the lhs assignee is in the `aos_exits`, all the names
417+
* 1. If the name of the lhs assignee is in the `aos_exits`, all the names
414418
* of the expressions with a type containing a matrix are returned.
415-
* 2. If the names of the rhs objects containing matrix types are in the subset of
419+
* 2. If the names of the rhs objects containing matrix types are in the subset of
416420
* aos_exits.
417421
* @param aos_exits A set of variables that can be demoted.
418422
* @param pattern The Stmt pattern to query.
@@ -437,15 +441,15 @@ let query_demotable_stmt (aos_exits : string Set.Poly.t)
437441

438442
(**
439443
* Modify a function and it's subexpressions from SoA <-> AoS and vice versa.
440-
* This performs demotion for sub expressions recursively. The top level
441-
* expression and it's sub expressions are demoted to SoA if
442-
* 1. The names of the variables in the subexpressions returning
444+
* This performs demotion for sub expressions recursively. The top level
445+
* expression and it's sub expressions are demoted to SoA if
446+
* 1. The names of the variables in the subexpressions returning
443447
* objects holding matrices are all in the modifiable set.
444448
* 2. The function does not support SoA
445449
* 3. The `force` argument is `true`
446-
* @param force_demotion If true, forces an expression and it's sub-expressions
450+
* @param force_demotion If true, forces an expression and it's sub-expressions
447451
* to be AoS.
448-
* @param modifiable_set The set of names that are either demotable
452+
* @param modifiable_set The set of names that are either demotable
449453
* to AoS or promotable to SoA.
450454
* @param kind A `Fun_kind.t`
451455
* @param exprs A list of expressions going into the function.
@@ -474,15 +478,15 @@ let rec modify_kind ?force_demotion:(force = false)
474478
( kind
475479
, List.map ~f:(modify_expr ~force_demotion:force modifiable_set) exprs )
476480

477-
(**
481+
(**
478482
* Modify an expression and it's subexpressions from SoA <-> AoS
479483
* and vice versa. The only real paths in the below is on the
480484
* functions and ternary expressions.
481485
*
482486
* The logic for functions is defined in `modify_kind`.
483487
* `TernaryIf` is forcefully demoted to AoS if the type of the expression
484488
* contains a matrix.
485-
* @param force_demotion If true, forces an expression and it's sub-expressions
489+
* @param force_demotion If true, forces an expression and it's sub-expressions
486490
* to be AoS.
487491
* @param modifiable_set The name of the variables whose
488492
* associated expressions we want to modify.
@@ -519,11 +523,13 @@ and modify_expr_pattern ?force_demotion:(force = false)
519523
, List.map ~f:(Index.map (mod_expr ~force_demotion:force)) indexed )
520524
| EAnd (lhs, rhs) -> EAnd (mod_expr lhs, mod_expr rhs)
521525
| EOr (lhs, rhs) -> EOr (mod_expr lhs, mod_expr rhs)
526+
| Promotion (expr, type_, ad_level) ->
527+
Promotion (mod_expr expr, type_, ad_level)
522528
| Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) ->
523529
pattern
524530

525-
(**
526-
* Given a Set of strings containing the names of objects that can be
531+
(**
532+
* Given a Set of strings containing the names of objects that can be
527533
* modified from AoS <-> SoA and vice versa, modify them within the expression.
528534
* @param mem_pattern The memory pattern to change expressions to.
529535
* @param modifiable_set The name of the variables whose
@@ -536,15 +542,15 @@ and modify_expr ?force_demotion:(force = false)
536542
pattern= modify_expr_pattern ~force_demotion:force modifiable_set pattern }
537543

538544
(**
539-
* Modify statement patterns in the MIR from AoS <-> SoA and vice versa
545+
* Modify statement patterns in the MIR from AoS <-> SoA and vice versa
540546
* For `Decl` and `Assignment`'s reading in parameters, we demote to AoS
541547
* if the `decl_id` (or assign name) is in the modifiable set and
542548
* otherwise promote the statement to `SoA`.
543549
* For general `Assignment` statements, we check if the assignee is in
544550
* the demotable set. If so, we force demotion of all of the rhs expressions.
545551
* All other statements recurse over their statements and expressions.
546-
*
547-
* @param pattern The statement pattern to modify
552+
*
553+
* @param pattern The statement pattern to modify
548554
* @param modifiable_set The name of the variable we are searching for.
549555
*)
550556
let rec modify_stmt_pattern
@@ -624,11 +630,11 @@ let rec modify_stmt_pattern
624630
| Skip | Break | Continue | Decl _ -> pattern
625631

626632
(**
627-
* Modify statement patterns in the MIR from AoS <-> SoA and vice versa
628-
* @param mem_pattern A mem_pattern to modify expressions to. For the
629-
* given memory pattern, this modifies
633+
* Modify statement patterns in the MIR from AoS <-> SoA and vice versa
634+
* @param mem_pattern A mem_pattern to modify expressions to. For the
635+
* given memory pattern, this modifies
630636
* statement patterns and expressions to it.
631-
* @param stmt The statement to modify.
637+
* @param stmt The statement to modify.
632638
* @param modifiable_set The name of the variable we are searching for.
633639
*)
634640
and modify_stmt (Stmt.Fixed.{pattern; _} as stmt)

src/analysis_and_optimization/Mir_utils.ml

+6
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ let rec expr_var_set Expr.Fixed.{pattern; meta} =
251251
| TernaryIf (expr1, expr2, expr3) -> union_recur [expr1; expr2; expr3]
252252
| Indexed (expr, ix) ->
253253
Set.Poly.union_list (expr_var_set expr :: List.map ix ~f:index_var_set)
254+
| Promotion (expr, _, _) -> expr_var_set expr
254255
| EAnd (expr1, expr2) | EOr (expr1, expr2) -> union_recur [expr1; expr2]
255256

256257
and index_var_set ix =
@@ -369,6 +370,7 @@ let rec expr_depth Expr.Fixed.{pattern; _} =
369370
+ max (expr_depth e)
370371
(Option.value ~default:0
371372
(List.max_elt ~compare:compare_int (List.map ~f:idx_depth l)) )
373+
| Promotion (expr, _, _) -> 1 + expr_depth expr
372374
| EAnd (e1, e2) | EOr (e1, e2) ->
373375
1
374376
+ Option.value ~default:0
@@ -413,6 +415,10 @@ let rec update_expr_ad_levels autodiffable_variables
413415
let e1 = update_expr_ad_levels autodiffable_variables e1 in
414416
let e2 = update_expr_ad_levels autodiffable_variables e2 in
415417
{pattern= EOr (e1, e2); meta= {e.meta with adlevel= ad_level_sup [e1; e2]}}
418+
| Promotion (expr, ut, ad) ->
419+
let expr' = update_expr_ad_levels autodiffable_variables expr in
420+
{ pattern= Promotion (expr', ut, ad)
421+
; meta= {e.meta with adlevel= ad_level_sup [expr']} }
416422
| Indexed (ixed, i_list) ->
417423
let ixed = update_expr_ad_levels autodiffable_variables ixed in
418424
let i_list =

src/analysis_and_optimization/Monotone_framework.ml

+13-11
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ let print_mfp to_string (mfp : (int, 'a entry_exit) Map.Poly.t)
2424
let rec free_vars_expr (e : Expr.Typed.t) =
2525
match e.pattern with
2626
| Var x -> Set.Poly.singleton x
27+
| Promotion (expr, _, _) -> free_vars_expr expr
2728
| Lit (_, _) -> Set.Poly.empty
2829
| FunApp (kind, l) -> free_vars_fnapp kind l
2930
| TernaryIf (e1, e2, e3) ->
@@ -133,15 +134,15 @@ let reverse (type l) (module F : FLOWGRAPH with type labels = l) =
133134
with type labels = l )
134135

135136
(** Modify the end nodes of a flowgraph to depend on its inits
136-
* To force the monotone framework to run until the program never changes
137-
* this function modifies the input `Flowgraph` so that it's end nodes
137+
* To force the monotone framework to run until the program never changes
138+
* this function modifies the input `Flowgraph` so that it's end nodes
138139
* depend on it's initial nodes. The inits of the reverse flowgraph are used
139-
* for this since we normally have both the forward and reverse flowgraphs
140+
* for this since we normally have both the forward and reverse flowgraphs
140141
* available.
141-
* @tparam l Type of the label for each flowgraph, most commonly an int
142+
* @tparam l Type of the label for each flowgraph, most commonly an int
142143
* @param Flowgraph The flowgraph to modify
143-
* @param RevFlowgraph The same flowgraph as `Flowgraph` but reversed.
144-
*
144+
* @param RevFlowgraph The same flowgraph as `Flowgraph` but reversed.
145+
*
145146
*)
146147
let make_circular_flowgraph (type l)
147148
(module Flowgraph : FLOWGRAPH with type labels = l)
@@ -544,6 +545,7 @@ let rec used_subexpressions_expr (e : Expr.Typed.t) =
544545
(Expr.Typed.Set.singleton e)
545546
( match e.pattern with
546547
| Var _ | Lit (_, _) -> Expr.Typed.Set.empty
548+
| Promotion (expr, _, _) -> used_subexpressions_expr expr
547549
| FunApp (k, l) ->
548550
Expr.Typed.Set.union_list
549551
(List.map ~f:used_subexpressions_expr (l @ Fun_kind.collect_exprs k))
@@ -1011,14 +1013,14 @@ let lazy_expressions_mfp
10111013
let used_not_latest_expressions_mfp = Mf4.mfp () in
10121014
(latest_expr, used_not_latest_expressions_mfp)
10131015

1014-
(** Run the minimal fixed point algorithm to deduce the smallest set of
1016+
(** Run the minimal fixed point algorithm to deduce the smallest set of
10151017
* variables that satisfy a set of conditions.
1016-
* @param Flowgraph The set of nodes to analyze
1017-
* @param flowgraph_to_mir Map of nodes to their actual values in the MIR
1018+
* @param Flowgraph The set of nodes to analyze
1019+
* @param flowgraph_to_mir Map of nodes to their actual values in the MIR
10181020
* @param initial_variables The set of variables to start in the set
1019-
* @param gen_variable Used in the transfer function to deduce variables
1021+
* @param gen_variable Used in the transfer function to deduce variables
10201022
* that should be in the set
1021-
*
1023+
*
10221024
*)
10231025
let minimal_variables_mfp
10241026
(module Circular_Fwd_Flowgraph : Monotone_framework_sigs.FLOWGRAPH

src/analysis_and_optimization/Optimize.ml

+5-2
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,9 @@ let rec inline_function_expression propto adt fim (Expr.Fixed.{pattern; _} as e)
224224
match pattern with
225225
| Var _ -> ([], [], e)
226226
| Lit (_, _) -> ([], [], e)
227+
| Promotion (expr, ut, ad) ->
228+
let d, sl, expr' = inline_function_expression propto adt fim expr in
229+
(d, sl, {e with pattern= Promotion (expr', ut, ad)})
227230
| FunApp (kind, es) -> (
228231
let d_list, s_list, es =
229232
inline_list (inline_function_expression propto adt fim) es in
@@ -1030,8 +1033,8 @@ let block_fixing mir =
10301033
(* TODO: add tests *)
10311034
(* TODO: add pass to get rid of redundant declarations? *)
10321035

1033-
(**
1034-
* A generic optimization pass for finding a minimal set of variables that
1036+
(**
1037+
* A generic optimization pass for finding a minimal set of variables that
10351038
* are generated by some circumstance, and then updating the MIR with that set.
10361039
* @param gen_variables: the variables that must be added to the set at
10371040
* the given statement

src/analysis_and_optimization/Partial_evaluator.ml

+1
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) =
9393
pattern=
9494
( match e.pattern with
9595
| Var _ | Lit (_, _) -> e.pattern
96+
| Promotion (expr, ut, ad) -> Promotion (eval_expr expr, ut, ad)
9697
| FunApp (kind, l) -> (
9798
let l = List.map ~f:(eval_expr ~preserve_stability) l in
9899
match kind with

src/frontend/Ast.ml

+14-4
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ type ('e, 'f) expression =
4444
| ImagNumeral of string
4545
| FunApp of 'f * identifier * 'e list
4646
| CondDistApp of 'f * identifier * 'e list
47+
| Promotion of 'e * UnsizedType.t * UnsizedType.autodifftype
4748
(* GetLP is deprecated *)
4849
| GetLP
4950
| GetTarget
@@ -250,9 +251,14 @@ type typed_program = typed_statement program [@@deriving sexp, compare, map]
250251
(** Forgetful function from typed to untyped expressions *)
251252
let rec untyped_expression_of_typed_expression ({expr; emeta} : typed_expression)
252253
: untyped_expression =
253-
{ expr=
254-
map_expression untyped_expression_of_typed_expression (fun _ -> ()) expr
255-
; emeta= {loc= emeta.loc} }
254+
match expr with
255+
| Promotion (e, _, _) -> untyped_expression_of_typed_expression e
256+
| _ ->
257+
{ expr=
258+
map_expression untyped_expression_of_typed_expression
259+
(fun _ -> ())
260+
expr
261+
; emeta= {loc= emeta.loc} }
256262

257263
let rec untyped_lvalue_of_typed_lvalue ({lval; lmeta} : typed_lval) :
258264
untyped_lval =
@@ -304,7 +310,11 @@ let rec id_of_lvalue {lval; _} =
304310

305311
let rec get_loc_expr (e : untyped_expression) =
306312
match e.expr with
307-
| TernaryIf (e, _, _) | BinOp (e, _, _) | PostfixOp (e, _) | Indexed (e, _) ->
313+
| TernaryIf (e, _, _)
314+
|BinOp (e, _, _)
315+
|PostfixOp (e, _)
316+
|Indexed (e, _)
317+
|Promotion (e, _, _) ->
308318
get_loc_expr e
309319
| PrefixOp (_, e) | ArrayExpr (e :: _) | RowVectorExpr (e :: _) | Paren e ->
310320
e.emeta.loc.begin_loc

0 commit comments

Comments
 (0)