Skip to content

Commit 116c4ce

Browse files
committed
Promotions from signatures
1 parent 4393d84 commit 116c4ce

File tree

13 files changed

+835
-331
lines changed

13 files changed

+835
-331
lines changed

src/frontend/Ast.ml

+8-3
Original file line numberDiff line numberDiff line change
@@ -251,9 +251,14 @@ type typed_program = typed_statement program [@@deriving sexp, compare, map]
251251
(** Forgetful function from typed to untyped expressions *)
252252
let rec untyped_expression_of_typed_expression ({expr; emeta} : typed_expression)
253253
: untyped_expression =
254-
{ expr=
255-
map_expression untyped_expression_of_typed_expression (fun _ -> ()) expr
256-
; emeta= {loc= emeta.loc} }
254+
match expr with
255+
| Promotion (e, _) -> untyped_expression_of_typed_expression e
256+
| _ ->
257+
{ expr=
258+
map_expression untyped_expression_of_typed_expression
259+
(fun _ -> ())
260+
expr
261+
; emeta= {loc= emeta.loc} }
257262

258263
let rec untyped_lvalue_of_typed_lvalue ({lval; lmeta} : typed_lval) :
259264
untyped_lval =

src/frontend/SignatureMismatch.ml

+43-26
Original file line numberDiff line numberDiff line change
@@ -110,41 +110,52 @@ let rec compare_errors e1 e2 =
110110
| SuffixMismatch _, _ | _, InputMismatch _ -> -1
111111
| InputMismatch _, _ | _, SuffixMismatch _ -> 1 ) )
112112

113+
type promotions = None | RealPromotion | ComplexPromotion
114+
115+
(*
116+
PROMOTION TODO: Not sure what depth actually means, or how it needs to change
117+
*)
113118
let rec check_same_type depth t1 t2 =
114119
let wrap_func = Result.map_error ~f:(fun e -> TypeMismatch (t1, t2, Some e)) in
115120
match (t1, t2) with
116-
| t1, t2 when t1 = t2 -> Ok ()
117-
| UnsizedType.(UReal, UInt) when depth < 1 -> Ok ()
118-
| UnsizedType.(UComplex, UInt) when depth < 1 -> Ok ()
119-
| UnsizedType.(UComplex, UReal) when depth < 1 -> Ok ()
121+
| t1, t2 when t1 = t2 -> Ok None
122+
| UnsizedType.(UReal, UInt) when depth < 1 -> Ok RealPromotion
123+
| UnsizedType.(UComplex, UInt) when depth < 1 -> Ok ComplexPromotion
124+
| UnsizedType.(UComplex, UReal) when depth < 1 -> Ok ComplexPromotion
120125
| UFun (_, _, s1, _), UFun (_, _, s2, _)
121126
when Fun_kind.without_propto s1 <> Fun_kind.without_propto s2 ->
122127
Error
123128
(SuffixMismatch (Fun_kind.without_propto s1, Fun_kind.without_propto s2))
124129
|> wrap_func
125130
| UFun (_, rt1, _, _), UFun (_, rt2, _, _) when rt1 <> rt2 ->
126131
Error (ReturnTypeMismatch (rt1, rt2)) |> wrap_func
127-
| UFun (l1, _, _, _), UFun (l2, _, _, _) ->
128-
check_compatible_arguments (depth + 1) l2 l1
129-
|> Result.map_error ~f:(fun e -> InputMismatch e)
130-
|> wrap_func
132+
| UFun (l1, _, _, _), UFun (l2, _, _, _) -> (
133+
match check_compatible_arguments (depth + 1) l2 l1 with
134+
| Ok _ -> Ok None
135+
| Error e -> Error (InputMismatch e) |> wrap_func )
136+
(* PROMOTION TODO: Need to allow array promotion, this doesn't quite work? *)
137+
| UArray ut1, UArray ut2 ->
138+
check_same_type depth ut1 ut2
139+
|> Result.map_error ~f:(function
140+
| TypeMismatch (_, _, _) -> TypeMismatch (t1, t2, None)
141+
| e -> e )
131142
| t1, t2 -> Error (TypeMismatch (t1, t2, None))
132143

133-
and check_compatible_arguments depth args1 args2 =
134-
match List.zip args1 args2 with
144+
and check_compatible_arguments depth typs args2 :
145+
(promotions list, function_mismatch) result =
146+
match List.zip typs args2 with
135147
| List.Or_unequal_lengths.Unequal_lengths ->
136-
Error (ArgNumMismatch (List.length args1, List.length args2))
148+
Error (ArgNumMismatch (List.length typs, List.length args2))
137149
| Ok l ->
138-
List.find_mapi l ~f:(fun i ((ad1, ut1), (ad2, ut2)) ->
150+
List.mapi l ~f:(fun i ((ad1, ut1), (ad2, ut2)) ->
139151
match check_same_type depth ut1 ut2 with
140-
| Error e -> Some (ArgError (i + 1, e))
141-
| Ok _ ->
142-
if ad1 = ad2 then None
152+
| Error e -> Error (ArgError (i + 1, e))
153+
| Ok p ->
154+
if ad1 = ad2 then Ok p
143155
else if depth < 2 && UnsizedType.autodifftype_can_convert ad1 ad2
144-
then None
145-
else Some (ArgError (i + 1, DataOnlyError)) )
146-
|> Option.map ~f:Result.fail
147-
|> Option.value ~default:(Ok ())
156+
then Ok p
157+
else Error (ArgError (i + 1, DataOnlyError)) )
158+
|> Result.all
148159

149160
let check_compatible_arguments_mod_conv = check_compatible_arguments 0
150161
let max_n_errors = 5
@@ -157,13 +168,19 @@ let extract_function_types f =
157168
Some (return, args, (fun x -> UserDefined x), mem)
158169
| _ -> None
159170

160-
let arg_type x = Ast.(x.emeta.ad_level, x.emeta.type_)
161-
let get_arg_types = List.map ~f:arg_type
171+
let promote es promotions =
172+
List.map2_exn es promotions ~f:(fun (exp : Ast.typed_expression) prom ->
173+
let emeta = exp.emeta in
174+
match prom with
175+
| RealPromotion when emeta.type_ <> UReal ->
176+
Ast.{expr= Ast.Promotion (exp, UReal); emeta}
177+
| ComplexPromotion when emeta.type_ <> UComplex ->
178+
{expr= Promotion (exp, UComplex); emeta}
179+
| _ -> exp )
162180

163-
let returntype env name arg_exprs =
181+
let returntype env name args =
164182
(* NB: Variadic arguments are special-cased in the typechecker and not handled here *)
165183
let name = Utils.stdlib_distribution_name name in
166-
let args = get_arg_types arg_exprs in
167184
Environment.find env name
168185
|> List.filter_map ~f:extract_function_types
169186
|> List.sort ~compare:(fun (x, _, _, _) (y, _, _, _) ->
@@ -173,7 +190,7 @@ let returntype env name arg_exprs =
173190
~f:(fun errors (rt, tys, funkind_constructor, _) ->
174191
match check_compatible_arguments 0 tys args with
175192
(* TODO instead of unit, return Ast.typed_expr list which could contain promotions*)
176-
| Ok () -> Stop (Ok (rt, funkind_constructor))
193+
| Ok p -> Stop (Ok (rt, funkind_constructor, p))
177194
| Error e -> Continue (((rt, tys), e) :: errors) )
178195
~finish:(fun errors ->
179196
let errors =
@@ -203,13 +220,13 @@ let check_variadic_args allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys
203220
if suffix = FnPlain || (allow_lpdf && suffix = FnLpdf ()) then
204221
match check_compatible_arguments 1 mandatory mandatory_fun_arg_tys with
205222
| Error x -> wrap_func_error (InputMismatch x)
206-
| Ok () -> (
223+
| Ok _ -> (
207224
match check_same_type 1 return_type fun_return with
208225
| Error _ ->
209226
wrap_func_error
210227
(ReturnTypeMismatch
211228
(ReturnType fun_return, ReturnType return_type) )
212-
| Ok () ->
229+
| Ok _ ->
213230
let expected_args =
214231
((UnsizedType.AutoDiffable, func_type) :: mandatory_arg_tys)
215232
@ variadic_arg_tys in

src/frontend/SignatureMismatch.mli

+11-4
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,23 @@ type signature_error =
1919
(UnsizedType.returntype * (UnsizedType.autodifftype * UnsizedType.t) list)
2020
* function_mismatch
2121

22+
type promotions = private None | RealPromotion | ComplexPromotion
23+
2224
val check_compatible_arguments_mod_conv :
2325
(UnsizedType.autodifftype * UnsizedType.t) list
2426
-> (UnsizedType.autodifftype * UnsizedType.t) list
25-
-> (unit (* Ast.typed_expression list *), function_mismatch) result
27+
-> (promotions list, function_mismatch) result
28+
29+
val promote :
30+
Ast.typed_expression list -> promotions list -> Ast.typed_expression list
2631

2732
val returntype :
2833
Environment.t
2934
-> string
30-
-> Ast.typed_expression list
31-
-> ( UnsizedType.returntype * (bool Middle.Fun_kind.suffix -> Ast.fun_kind)
35+
-> (UnsizedType.autodifftype * UnsizedType.t) list
36+
-> ( UnsizedType.returntype
37+
* (bool Middle.Fun_kind.suffix -> Ast.fun_kind)
38+
* promotions list
3239
, signature_error list * bool )
3340
result
3441

@@ -38,7 +45,7 @@ val check_variadic_args :
3845
-> (UnsizedType.autodifftype * UnsizedType.t) list
3946
-> UnsizedType.t
4047
-> (UnsizedType.autodifftype * UnsizedType.t) list
41-
-> ( unit
48+
-> ( promotions list
4249
, (UnsizedType.autodifftype * UnsizedType.t) list * function_mismatch )
4350
result
4451

src/frontend/Typechecker.ml

+22-12
Original file line numberDiff line numberDiff line change
@@ -424,15 +424,17 @@ let check_fn ~is_cond_dist loc tenv id es =
424424
(Env.nearest_ident tenv id.name) )
425425
|> error
426426
| _ (* a function *) -> (
427-
match SignatureMismatch.returntype tenv id.name es with
428-
| Ok (Void, _) ->
427+
match SignatureMismatch.returntype tenv id.name (get_arg_types es) with
428+
| Ok (Void, _, _) ->
429429
Semantic_error.returning_fn_expected_nonreturning_found loc id.name
430430
|> error
431-
| Ok (ReturnType ut, fnk) ->
431+
| Ok (ReturnType ut, fnk, promotions) ->
432432
mk_typed_expression
433433
~expr:
434434
(mk_fun_app ~is_cond_dist
435-
(fnk (Fun_kind.suffix_from_name id.name), id, es) )
435+
( fnk (Fun_kind.suffix_from_name id.name)
436+
, id
437+
, SignatureMismatch.promote es promotions ) )
436438
~ad_level:(expr_ad_lub es) ~type_:ut ~loc
437439
| Error x ->
438440
es
@@ -458,9 +460,11 @@ let check_reduce_sum ~is_cond_dist loc id es =
458460
SignatureMismatch.check_variadic_args true mandatory_args
459461
mandatory_fun_args UReal (get_arg_types es)
460462
with
461-
| Ok () ->
463+
| Ok promotions ->
462464
mk_typed_expression
463-
~expr:(mk_fun_app ~is_cond_dist (StanLib FnPlain, id, es))
465+
~expr:
466+
(mk_fun_app ~is_cond_dist
467+
(StanLib FnPlain, id, SignatureMismatch.promote es promotions) )
464468
~ad_level:(expr_ad_lub es) ~type_:UnsizedType.UReal ~loc
465469
| Error (expected_args, err) ->
466470
Semantic_error.illtyped_reduce_sum loc id.name
@@ -498,9 +502,11 @@ let check_variadic_ode ~is_cond_dist loc id es =
498502
Stan_math_signatures.variadic_ode_mandatory_fun_args
499503
Stan_math_signatures.variadic_ode_fun_return_type (get_arg_types es)
500504
with
501-
| Ok () ->
505+
| Ok promotions ->
502506
mk_typed_expression
503-
~expr:(mk_fun_app ~is_cond_dist (StanLib FnPlain, id, es))
507+
~expr:
508+
(mk_fun_app ~is_cond_dist
509+
(StanLib FnPlain, id, SignatureMismatch.promote es promotions) )
504510
~ad_level:(expr_ad_lub es)
505511
~type_:Stan_math_signatures.variadic_ode_return_type ~loc
506512
| Error (expected_args, err) ->
@@ -702,12 +708,16 @@ let check_nrfn loc tenv id es =
702708
(Env.nearest_ident tenv id.name)
703709
|> error
704710
| _ (* a function *) -> (
705-
match SignatureMismatch.returntype tenv id.name es with
706-
| Ok (Void, fnk) ->
711+
match SignatureMismatch.returntype tenv id.name (get_arg_types es) with
712+
| Ok (Void, fnk, promotions) ->
707713
mk_typed_statement
708-
~stmt:(NRFunApp (fnk (Fun_kind.suffix_from_name id.name), id, es))
714+
~stmt:
715+
(NRFunApp
716+
( fnk (Fun_kind.suffix_from_name id.name)
717+
, id
718+
, SignatureMismatch.promote es promotions ) )
709719
~return_type:NoReturnType ~loc
710-
| Ok (ReturnType _, _) ->
720+
| Ok (ReturnType _, _, _) ->
711721
Semantic_error.nonreturning_fn_expected_returning_found loc id.name
712722
|> error
713723
| Error x ->

src/stan_math_backend/Expression_gen.ml

+4-2
Original file line numberDiff line numberDiff line change
@@ -528,8 +528,10 @@ and pp_expr ppf Expr.Fixed.({pattern; meta} as e) =
528528
| Lit (Imaginary, s) -> pf ppf "to_complex(0, %s)" s
529529
| Lit ((Real | Int), s) -> pf ppf "%s" s
530530
| Promotion (expr, ut) ->
531-
pf ppf "stan::math::promote_scalar_t<%a>(%a)" pp_unsizedtype_local
532-
(meta.adlevel, ut) pp_expr expr
531+
if is_scalar expr && ut = UReal then pp_expr ppf expr
532+
else
533+
pf ppf "stan::math::promote_scalar_t<%a>(%a)" pp_unsizedtype_local
534+
(meta.adlevel, ut) pp_expr expr
533535
| FunApp
534536
( StanLib (op, _, _)
535537
, [ { meta= {type_= URowVector; _}

test/integration/bad/algebra_solver/stanc.expected

-76
Original file line numberDiff line numberDiff line change
@@ -128,47 +128,9 @@ where F2 = (vector, vector, data array[] real, data array[] int) => vector
128128
(<F2>, vector, vector, data array[] real, data array[] int) => vector
129129
Expected 5 arguments but found 8 arguments.
130130
$ ../../../../../install/default/bin/stanc bad_newton_x_r_type.stan
131-
Semantic error in 'bad_newton_x_r_type.stan', line 31, column 10 to column 69:
132-
-------------------------------------------------
133-
29: transformed parameters {
134-
30: vector[2] y_s_p;
135-
31: y_s_p = algebra_solver_newton(algebra_system, y, theta_p, x_r, x_i);
136-
^
137-
32: }
138-
33:
139-
-------------------------------------------------
140131

141-
Ill-typed arguments supplied to function 'algebra_solver_newton':
142-
(<F1>, vector, vector, array[] int, array[] int)
143-
where F1 = (vector, vector, array[] real, array[] int) => vector
144-
Available signatures:
145-
(<F2>, vector, vector, data array[] real, data array[] int) => vector
146-
where F2 = (vector, vector, data array[] real, data array[] int) => vector
147-
The fourth argument must be array[] real but got array[] int
148-
(<F2>, vector, vector, data array[] real, data array[] int, data real,
149-
data real, data real) => vector
150-
Expected 8 arguments but found 5 arguments.
151132
$ ../../../../../install/default/bin/stanc bad_newton_x_r_type_control.stan
152-
Semantic error in 'bad_newton_x_r_type_control.stan', line 31, column 10 to column 85:
153-
-------------------------------------------------
154-
29: transformed parameters {
155-
30: vector[2] y_s_p;
156-
31: y_s_p = algebra_solver_newton(algebra_system, y, theta_p, x_r, x_i, 0.01, 0.01, 10);
157-
^
158-
32: }
159-
33:
160-
-------------------------------------------------
161133

162-
Ill-typed arguments supplied to function 'algebra_solver_newton':
163-
(<F1>, vector, vector, array[] int, array[] int, real, real, int)
164-
where F1 = (vector, vector, array[] real, array[] int) => vector
165-
Available signatures:
166-
(<F2>, vector, vector, data array[] real, data array[] int, data real,
167-
data real, data real) => vector
168-
where F2 = (vector, vector, data array[] real, data array[] int) => vector
169-
The fourth argument must be array[] real but got array[] int
170-
(<F2>, vector, vector, data array[] real, data array[] int) => vector
171-
Expected 5 arguments but found 8 arguments.
172134
$ ../../../../../install/default/bin/stanc bad_newton_x_r_var_type.stan
173135
Semantic error in 'bad_newton_x_r_var_type.stan', line 31, column 10 to column 71:
174136
-------------------------------------------------
@@ -374,47 +336,9 @@ where F2 = (vector, vector, data array[] real, data array[] int) => vector
374336
(<F2>, vector, vector, data array[] real, data array[] int) => vector
375337
Expected 5 arguments but found 8 arguments.
376338
$ ../../../../../install/default/bin/stanc bad_x_r_type.stan
377-
Semantic error in 'bad_x_r_type.stan', line 31, column 10 to column 62:
378-
-------------------------------------------------
379-
29: transformed parameters {
380-
30: vector[2] y_s_p;
381-
31: y_s_p = algebra_solver(algebra_system, y, theta_p, x_r, x_i);
382-
^
383-
32: }
384-
33:
385-
-------------------------------------------------
386339

387-
Ill-typed arguments supplied to function 'algebra_solver':
388-
(<F1>, vector, vector, array[] int, array[] int)
389-
where F1 = (vector, vector, array[] real, array[] int) => vector
390-
Available signatures:
391-
(<F2>, vector, vector, data array[] real, data array[] int) => vector
392-
where F2 = (vector, vector, data array[] real, data array[] int) => vector
393-
The fourth argument must be array[] real but got array[] int
394-
(<F2>, vector, vector, data array[] real, data array[] int, data real,
395-
data real, data real) => vector
396-
Expected 8 arguments but found 5 arguments.
397340
$ ../../../../../install/default/bin/stanc bad_x_r_type_control.stan
398-
Semantic error in 'bad_x_r_type_control.stan', line 31, column 10 to column 78:
399-
-------------------------------------------------
400-
29: transformed parameters {
401-
30: vector[2] y_s_p;
402-
31: y_s_p = algebra_solver(algebra_system, y, theta_p, x_r, x_i, 0.01, 0.01, 10);
403-
^
404-
32: }
405-
33:
406-
-------------------------------------------------
407341

408-
Ill-typed arguments supplied to function 'algebra_solver':
409-
(<F1>, vector, vector, array[] int, array[] int, real, real, int)
410-
where F1 = (vector, vector, array[] real, array[] int) => vector
411-
Available signatures:
412-
(<F2>, vector, vector, data array[] real, data array[] int, data real,
413-
data real, data real) => vector
414-
where F2 = (vector, vector, data array[] real, data array[] int) => vector
415-
The fourth argument must be array[] real but got array[] int
416-
(<F2>, vector, vector, data array[] real, data array[] int) => vector
417-
Expected 5 arguments but found 8 arguments.
418342
$ ../../../../../install/default/bin/stanc bad_x_r_var_type.stan
419343
Semantic error in 'bad_x_r_var_type.stan', line 31, column 10 to column 64:
420344
-------------------------------------------------

test/integration/bad/stanc.expected

-21
Original file line numberDiff line numberDiff line change
@@ -1780,28 +1780,7 @@ Semantic error in 'prob-poisson_log-trunc-low.stan', line 11, column 2 to column
17801780

17811781
Truncation is only defined if distribution has _lcdf and _lccdf functions implemented with appropriate signature.
17821782
$ ../../../../install/default/bin/stanc propto-bad1.stan
1783-
Semantic error in 'propto-bad1.stan', line 2, column 14 to column 37:
1784-
-------------------------------------------------
1785-
1: model {
1786-
2: target += normal_lupdf(1| {0}, 1);
1787-
^
1788-
3: }
1789-
-------------------------------------------------
17901783

1791-
Ill-typed arguments supplied to function 'normal_lupdf':
1792-
(int, array[] int, int)
1793-
Available signatures:
1794-
(real, real, real) => real
1795-
The second argument must be real but got array[] int
1796-
(real, real, array[] real) => real
1797-
The second argument must be real but got array[] int
1798-
(real, real, vector) => real
1799-
The second argument must be real but got array[] int
1800-
(real, real, row_vector) => real
1801-
The second argument must be real but got array[] int
1802-
(real, vector, real) => real
1803-
The second argument must be vector but got array[] int
1804-
(Additional signatures omitted)
18051784
$ ../../../../install/default/bin/stanc real-pdf.stan
18061785
Semantic error in 'real-pdf.stan', line 2, column 2 to line 4, column 3:
18071786
-------------------------------------------------

0 commit comments

Comments
 (0)