@@ -17,7 +17,7 @@ let rec matrix_set Expr.Fixed.{pattern; meta= Expr.Typed.Meta.{type_; _} as meta
17
17
if UnsizedType. contains_eigen_type type_ then union_recur exprs
18
18
else Set.Poly. empty
19
19
| TernaryIf (_ , expr2 , expr3 ) -> union_recur [expr2; expr3]
20
- | Indexed (expr , _ ) -> matrix_set expr
20
+ | Indexed (expr , _ ) | Promotion ( expr , _ , _ ) -> matrix_set expr
21
21
| EAnd (expr1 , expr2 ) | EOr (expr1 , expr2 ) -> union_recur [expr1; expr2]
22
22
else Set.Poly. empty
23
23
@@ -45,7 +45,7 @@ let is_nonzero_subset ~set ~subset =
45
45
&& not (Set.Poly. is_empty subset)
46
46
47
47
(* *
48
- * Check an expression to count how many times we see a single index.
48
+ * Check an expression to count how many times we see a single index.
49
49
* @param acc An accumulator from previous folds of multiple expressions.
50
50
* @param pattern The expression patterns to match against
51
51
*)
@@ -62,6 +62,7 @@ let rec count_single_idx_exprs (acc : int) Expr.Fixed.{pattern; _} : int =
62
62
acc
63
63
+ count_single_idx_exprs 0 idx_expr
64
64
+ List. fold_left ~init: 0 ~f: count_single_idx indexed
65
+ | Promotion (expr , _ , _ ) -> count_single_idx_exprs acc expr
65
66
| EAnd (lhs , rhs ) ->
66
67
acc + count_single_idx_exprs 0 lhs + count_single_idx_exprs 0 rhs
67
68
| EOr (lhs , rhs ) ->
@@ -72,9 +73,9 @@ let rec count_single_idx_exprs (acc : int) Expr.Fixed.{pattern; _} : int =
72
73
(* *
73
74
* Check an Index to count how many times we see a single index.
74
75
* @param acc An accumulator from previous folds of multiple expressions.
75
- * @param idx An Index to match. For Single types this adds 1 to the
76
- * acc. For Upfrom and MultiIndex types we check the inner expression
77
- * for a Single index. All and Between cannot be Single cell access
76
+ * @param idx An Index to match. For Single types this adds 1 to the
77
+ * acc. For Upfrom and MultiIndex types we check the inner expression
78
+ * for a Single index. All and Between cannot be Single cell access
78
79
* and so pass acc along.
79
80
*)
80
81
and count_single_idx (acc : int ) (idx : Expr.Typed.Meta.t Expr.Fixed.t Index.t )
@@ -84,13 +85,13 @@ and count_single_idx (acc : int) (idx : Expr.Typed.Meta.t Expr.Fixed.t Index.t)
84
85
| Single _ -> acc + 1
85
86
86
87
(* *
87
- * Find indices on Matrix and Vector types that perform single
88
+ * Find indices on Matrix and Vector types that perform single
88
89
* cell access. Returns true if it finds
89
90
* a vector, row vector, matrix, or matrix with single cell access
90
- * as well as an array of any of the above that is accessing the
91
+ * as well as an array of any of the above that is accessing the
91
92
* inner matrix types cell.
92
93
* @param ut An UnsizedType to match against.
93
- * @param index This list is checked for Single cell access
94
+ * @param index This list is checked for Single cell access
94
95
* either at the top level or within the `Index` types of the list.
95
96
*)
96
97
let rec is_uni_eigen_loop_indexing in_loop (ut : UnsizedType.t )
@@ -122,7 +123,7 @@ let is_fun_soa_supported name exprs =
122
123
* see the docs for `query_initial_demotable_funs`.
123
124
* @param in_loop a boolean to signify if the expression exists inside
124
125
* of a loop. If so, the names of matrix and vector like objects
125
- * will be returned if the matrix or vector is accessed by single
126
+ * will be returned if the matrix or vector is accessed by single
126
127
* cell indexing.
127
128
*)
128
129
let rec query_initial_demotable_expr (in_loop : bool ) ~(acc : string Set.Poly.t )
@@ -147,6 +148,7 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
147
148
Set.Poly. union acc index_demotes
148
149
| Var (_ : string ) | Lit ((_ : Expr.Fixed.Pattern.litType ), (_ : string )) ->
149
150
acc
151
+ | Promotion (expr , _ , _ ) -> query_expr acc expr
150
152
| TernaryIf (predicate , texpr , fexpr ) ->
151
153
let predicate_demotes = query_expr acc predicate in
152
154
Set.Poly. union
@@ -159,18 +161,18 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
159
161
Set.Poly. union (query_expr full_lhs_rhs lhs) (query_expr full_lhs_rhs rhs)
160
162
161
163
(* *
162
- * Query a function to detect if it or any of its used
164
+ * Query a function to detect if it or any of its used
163
165
* expression's objects or expressions should be demoted to AoS.
164
166
*
165
167
* The logic here demotes the expressions in a function to AoS if
166
- * the function's inner expression returns has a meta type containing a matrix
168
+ * the function's inner expression returns has a meta type containing a matrix
167
169
* and either of :
168
170
* (1) The function is user defined and the UDFs inputs are matrices.
169
171
* (2) The Stan math function cannot support AoS
170
172
* @param in_loop A boolean to specify the logic of indexing expressions. See
171
173
* `query_initial_demotable_expr` for an explanation of the logic.
172
- * @param kind The function type, for StanLib functions we check if the
173
- * function supports SoA and for UserDefined functions we always fail
174
+ * @param kind The function type, for StanLib functions we check if the
175
+ * function supports SoA and for UserDefined functions we always fail
174
176
* and return back all of the names of the objects passed in expressions
175
177
* to the UDF.
176
178
* exprs The expression list passed to the functions.
@@ -212,7 +214,8 @@ let rec is_any_soa_supported_expr
212
214
match pattern with
213
215
| FunApp (kind , (exprs : Expr.Typed.Meta.t Expr.Fixed.t list )) ->
214
216
is_any_soa_supported_fun_expr kind exprs
215
- | Indexed (expr , (_ : Typed.Meta.t Fixed.t Index.t list )) ->
217
+ | Indexed (expr, (_ : Typed.Meta.t Fixed.t Index.t list ))
218
+ | Promotion (expr , _ , _ ) ->
216
219
is_any_soa_supported_expr expr
217
220
| Var (_ : string ) | Lit ((_ : Expr.Fixed.Pattern.litType ), (_ : string )) ->
218
221
true
@@ -248,7 +251,8 @@ let rec is_any_ad_real_data_matrix_expr
248
251
match pattern with
249
252
| FunApp (kind , (exprs : Expr.Typed.Meta.t Expr.Fixed.t list )) ->
250
253
is_any_ad_real_data_matrix_expr_fun kind exprs
251
- | Indexed (expr , _ ) -> is_any_ad_real_data_matrix_expr expr
254
+ | Indexed (expr , _ ) | Promotion (expr , _ , _ ) ->
255
+ is_any_ad_real_data_matrix_expr expr
252
256
| Var (_ : string ) | Lit ((_ : Expr.Fixed.Pattern.litType ), (_ : string )) ->
253
257
false
254
258
| TernaryIf (_ , texpr , fexpr ) ->
@@ -303,15 +307,15 @@ and is_any_ad_real_data_matrix_expr_fun (kind : 'a Fun_kind.t)
303
307
304
308
(* *
305
309
* Query to find the initial set of objects in statements that cannot be SoA.
306
- * This is mostly recursive over expressions and statements, with the exception of
310
+ * This is mostly recursive over expressions and statements, with the exception of
307
311
* functions and Assignments.
308
312
*
309
313
* For assignments:
310
314
* We demote the LHS variable if any of the following are true:
311
- * 1. None of the RHS's functions are able to accept SoA matrices
315
+ * 1. None of the RHS's functions are able to accept SoA matrices
312
316
* and the rhs is not an internal compiler function.
313
317
* 2. A single cell of the LHS is being assigned within a loop.
314
- * 3. The top level expression on the RHS is a combination of only
318
+ * 3. The top level expression on the RHS is a combination of only
315
319
* data matrices and scalar types. Operations on data matrix and
316
320
* scalar values in Stan math will return a AoS matrix. We currently
317
321
* have no way to tell Stan math to return a SoA matrix.
@@ -408,11 +412,11 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
408
412
(* * Look through a statement to see whether the objects used in it need to be
409
413
* modified from SoA to AoS. Returns the set of object names that need demoted
410
414
* in a statement, if any.
411
- * This function looks at Assignment statements, and returns back the
415
+ * This function looks at Assignment statements, and returns back the
412
416
* set of top level object names given:
413
- * 1. If the name of the lhs assignee is in the `aos_exits`, all the names
417
+ * 1. If the name of the lhs assignee is in the `aos_exits`, all the names
414
418
* of the expressions with a type containing a matrix are returned.
415
- * 2. If the names of the rhs objects containing matrix types are in the subset of
419
+ * 2. If the names of the rhs objects containing matrix types are in the subset of
416
420
* aos_exits.
417
421
* @param aos_exits A set of variables that can be demoted.
418
422
* @param pattern The Stmt pattern to query.
@@ -437,15 +441,15 @@ let query_demotable_stmt (aos_exits : string Set.Poly.t)
437
441
438
442
(* *
439
443
* Modify a function and it's subexpressions from SoA <-> AoS and vice versa.
440
- * This performs demotion for sub expressions recursively. The top level
441
- * expression and it's sub expressions are demoted to SoA if
442
- * 1. The names of the variables in the subexpressions returning
444
+ * This performs demotion for sub expressions recursively. The top level
445
+ * expression and it's sub expressions are demoted to SoA if
446
+ * 1. The names of the variables in the subexpressions returning
443
447
* objects holding matrices are all in the modifiable set.
444
448
* 2. The function does not support SoA
445
449
* 3. The `force` argument is `true`
446
- * @param force_demotion If true, forces an expression and it's sub-expressions
450
+ * @param force_demotion If true, forces an expression and it's sub-expressions
447
451
* to be AoS.
448
- * @param modifiable_set The set of names that are either demotable
452
+ * @param modifiable_set The set of names that are either demotable
449
453
* to AoS or promotable to SoA.
450
454
* @param kind A `Fun_kind.t`
451
455
* @param exprs A list of expressions going into the function.
@@ -474,15 +478,15 @@ let rec modify_kind ?force_demotion:(force = false)
474
478
( kind
475
479
, List. map ~f: (modify_expr ~force_demotion: force modifiable_set) exprs )
476
480
477
- (* *
481
+ (* *
478
482
* Modify an expression and it's subexpressions from SoA <-> AoS
479
483
* and vice versa. The only real paths in the below is on the
480
484
* functions and ternary expressions.
481
485
*
482
486
* The logic for functions is defined in `modify_kind`.
483
487
* `TernaryIf` is forcefully demoted to AoS if the type of the expression
484
488
* contains a matrix.
485
- * @param force_demotion If true, forces an expression and it's sub-expressions
489
+ * @param force_demotion If true, forces an expression and it's sub-expressions
486
490
* to be AoS.
487
491
* @param modifiable_set The name of the variables whose
488
492
* associated expressions we want to modify.
@@ -519,11 +523,13 @@ and modify_expr_pattern ?force_demotion:(force = false)
519
523
, List. map ~f: (Index. map (mod_expr ~force_demotion: force)) indexed )
520
524
| EAnd (lhs , rhs ) -> EAnd (mod_expr lhs, mod_expr rhs)
521
525
| EOr (lhs , rhs ) -> EOr (mod_expr lhs, mod_expr rhs)
526
+ | Promotion (expr , type_ , ad_level ) ->
527
+ Promotion (mod_expr expr, type_, ad_level)
522
528
| Var (_ : string ) | Lit ((_ : Expr.Fixed.Pattern.litType ), (_ : string )) ->
523
529
pattern
524
530
525
- (* *
526
- * Given a Set of strings containing the names of objects that can be
531
+ (* *
532
+ * Given a Set of strings containing the names of objects that can be
527
533
* modified from AoS <-> SoA and vice versa, modify them within the expression.
528
534
* @param mem_pattern The memory pattern to change expressions to.
529
535
* @param modifiable_set The name of the variables whose
@@ -536,15 +542,15 @@ and modify_expr ?force_demotion:(force = false)
536
542
pattern= modify_expr_pattern ~force_demotion: force modifiable_set pattern }
537
543
538
544
(* *
539
- * Modify statement patterns in the MIR from AoS <-> SoA and vice versa
545
+ * Modify statement patterns in the MIR from AoS <-> SoA and vice versa
540
546
* For `Decl` and `Assignment`'s reading in parameters, we demote to AoS
541
547
* if the `decl_id` (or assign name) is in the modifiable set and
542
548
* otherwise promote the statement to `SoA`.
543
549
* For general `Assignment` statements, we check if the assignee is in
544
550
* the demotable set. If so, we force demotion of all of the rhs expressions.
545
551
* All other statements recurse over their statements and expressions.
546
- *
547
- * @param pattern The statement pattern to modify
552
+ *
553
+ * @param pattern The statement pattern to modify
548
554
* @param modifiable_set The name of the variable we are searching for.
549
555
*)
550
556
let rec modify_stmt_pattern
@@ -624,11 +630,11 @@ let rec modify_stmt_pattern
624
630
| Skip | Break | Continue | Decl _ -> pattern
625
631
626
632
(* *
627
- * Modify statement patterns in the MIR from AoS <-> SoA and vice versa
628
- * @param mem_pattern A mem_pattern to modify expressions to. For the
629
- * given memory pattern, this modifies
633
+ * Modify statement patterns in the MIR from AoS <-> SoA and vice versa
634
+ * @param mem_pattern A mem_pattern to modify expressions to. For the
635
+ * given memory pattern, this modifies
630
636
* statement patterns and expressions to it.
631
- * @param stmt The statement to modify.
637
+ * @param stmt The statement to modify.
632
638
* @param modifiable_set The name of the variable we are searching for.
633
639
*)
634
640
and modify_stmt (Stmt.Fixed. {pattern; _} as stmt )
0 commit comments