Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add promoter expression #1004

Merged
merged 25 commits into from
Jan 12, 2022
Merged
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/analysis_and_optimization/Mir_utils.ml
Original file line number Diff line number Diff line change
@@ -243,6 +243,7 @@ let rec expr_var_set Expr.Fixed.{pattern; meta} =
| TernaryIf (expr1, expr2, expr3) -> union_recur [expr1; expr2; expr3]
| Indexed (expr, ix) ->
Set.Poly.union_list (expr_var_set expr :: List.map ix ~f:index_var_set)
| Promotion (expr, _) -> expr_var_set expr
| EAnd (expr1, expr2) | EOr (expr1, expr2) -> union_recur [expr1; expr2]

and index_var_set ix =
@@ -361,6 +362,7 @@ let rec expr_depth Expr.Fixed.{pattern; _} =
+ max (expr_depth e)
(Option.value ~default:0
(List.max_elt ~compare:compare_int (List.map ~f:idx_depth l)) )
| Promotion (expr, _) -> 1 + expr_depth expr
| EAnd (e1, e2) | EOr (e1, e2) ->
1
+ Option.value ~default:0
@@ -405,6 +407,10 @@ let rec update_expr_ad_levels autodiffable_variables
let e1 = update_expr_ad_levels autodiffable_variables e1 in
let e2 = update_expr_ad_levels autodiffable_variables e2 in
{pattern= EOr (e1, e2); meta= {e.meta with adlevel= ad_level_sup [e1; e2]}}
| Promotion (expr, ut) ->
let expr' = update_expr_ad_levels autodiffable_variables expr in
{ pattern= Promotion (expr', ut)
; meta= {e.meta with adlevel= ad_level_sup [expr']} }
| Indexed (ixed, i_list) ->
let ixed = update_expr_ad_levels autodiffable_variables ixed in
let i_list =
2 changes: 2 additions & 0 deletions src/analysis_and_optimization/Monotone_framework.ml
Original file line number Diff line number Diff line change
@@ -24,6 +24,7 @@ let print_mfp to_string (mfp : (int, 'a entry_exit) Map.Poly.t)
let rec free_vars_expr (e : Expr.Typed.t) =
match e.pattern with
| Var x -> Set.Poly.singleton x
| Promotion (expr, _) -> free_vars_expr expr
| Lit (_, _) -> Set.Poly.empty
| FunApp (kind, l) -> free_vars_fnapp kind l
| TernaryIf (e1, e2, e3) ->
@@ -544,6 +545,7 @@ let rec used_subexpressions_expr (e : Expr.Typed.t) =
(Expr.Typed.Set.singleton e)
( match e.pattern with
| Var _ | Lit (_, _) -> Expr.Typed.Set.empty
| Promotion (expr, _) -> used_subexpressions_expr expr
| FunApp (k, l) ->
Expr.Typed.Set.union_list
(List.map ~f:used_subexpressions_expr (l @ Fun_kind.collect_exprs k))
7 changes: 5 additions & 2 deletions src/analysis_and_optimization/Optimize.ml
Original file line number Diff line number Diff line change
@@ -224,6 +224,9 @@ let rec inline_function_expression propto adt fim (Expr.Fixed.{pattern; _} as e)
match pattern with
| Var _ -> ([], [], e)
| Lit (_, _) -> ([], [], e)
| Promotion (expr, ut) ->
let d, sl, expr' = inline_function_expression propto adt fim expr in
(d, sl, {e with pattern= Promotion (expr', ut)})
| FunApp (kind, es) -> (
let d_list, s_list, es =
inline_list (inline_function_expression propto adt fim) es in
@@ -1024,8 +1027,8 @@ let block_fixing mir =
(* TODO: add tests *)
(* TODO: add pass to get rid of redundant declarations? *)

(**
* A generic optimization pass for finding a minimal set of variables that
(**
* A generic optimization pass for finding a minimal set of variables that
* are generated by some circumstance, and then updating the MIR with that set.
* @param gen_variables: the variables that must be added to the set at
* the given statement
1 change: 1 addition & 0 deletions src/analysis_and_optimization/Partial_evaluator.ml
Original file line number Diff line number Diff line change
@@ -93,6 +93,7 @@ let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) =
pattern=
( match e.pattern with
| Var _ | Lit (_, _) -> e.pattern
| Promotion (expr, ut) -> Promotion (eval_expr expr, ut)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[optional] this could be a bit smarter, like promoting literals to literals and collapsing promotion of promotion.

| FunApp (kind, l) -> (
let l = List.map ~f:(eval_expr ~preserve_stability) l in
match kind with
18 changes: 14 additions & 4 deletions src/frontend/Ast.ml
Original file line number Diff line number Diff line change
@@ -44,6 +44,7 @@ type ('e, 'f) expression =
| ImagNumeral of string
| FunApp of 'f * identifier * 'e list
| CondDistApp of 'f * identifier * 'e list
| Promotion of 'e * UnsizedType.t
(* GetLP is deprecated *)
| GetLP
| GetTarget
@@ -250,9 +251,14 @@ type typed_program = typed_statement program [@@deriving sexp, compare, map]
(** Forgetful function from typed to untyped expressions *)
let rec untyped_expression_of_typed_expression ({expr; emeta} : typed_expression)
: untyped_expression =
{ expr=
map_expression untyped_expression_of_typed_expression (fun _ -> ()) expr
; emeta= {loc= emeta.loc} }
match expr with
| Promotion (e, _) -> untyped_expression_of_typed_expression e
| _ ->
{ expr=
map_expression untyped_expression_of_typed_expression
(fun _ -> ())
expr
; emeta= {loc= emeta.loc} }

let rec untyped_lvalue_of_typed_lvalue ({lval; lmeta} : typed_lval) :
untyped_lval =
@@ -304,7 +310,11 @@ let rec id_of_lvalue {lval; _} =

let rec get_loc_expr (e : untyped_expression) =
match e.expr with
| TernaryIf (e, _, _) | BinOp (e, _, _) | PostfixOp (e, _) | Indexed (e, _) ->
| TernaryIf (e, _, _)
|BinOp (e, _, _)
|PostfixOp (e, _)
|Indexed (e, _)
|Promotion (e, _) ->
get_loc_expr e
| PrefixOp (_, e) | ArrayExpr (e :: _) | RowVectorExpr (e :: _) | Paren e ->
e.emeta.loc.begin_loc
1 change: 1 addition & 0 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
@@ -81,6 +81,7 @@ and trans_expr {Ast.expr; Ast.emeta} =
FunApp (CompilerInternal FnMakeRowVec, trans_exprs eles) |> ewrap
| Indexed (lhs, indices) ->
Indexed (trans_expr lhs, List.map ~f:trans_idx indices) |> ewrap
| Promotion (e, ty) -> Promotion (trans_expr e, ty) |> ewrap

and trans_idx = function
| Ast.All -> All
2 changes: 1 addition & 1 deletion src/frontend/Canonicalize.ml
Original file line number Diff line number Diff line change
@@ -142,7 +142,7 @@ let rec no_parens {expr; emeta} =
| i -> map_index keep_parens i )
l )
; emeta }
| ArrayExpr _ | RowVectorExpr _ | FunApp _ | CondDistApp _ ->
| ArrayExpr _ | RowVectorExpr _ | FunApp _ | CondDistApp _ | Promotion _ ->
{expr= map_expression no_parens ident expr; emeta}

and keep_parens {expr; emeta} =
1 change: 1 addition & 0 deletions src/frontend/Pretty_printing.ml
Original file line number Diff line number Diff line change
@@ -272,6 +272,7 @@ and pp_expression ppf ({expr= e_content; emeta= {loc; _}} : untyped_expression)
| ArrayExpr es -> pf ppf "{@[%a}@]" pp_list_of_expression (es, loc)
| RowVectorExpr es -> pf ppf "[@[%a]@]" pp_list_of_expression (es, loc)
| Paren e -> pf ppf "(%a)" pp_expression e
| Promotion (e, _) -> pp_expression ppf e
| Indexed (e, l) -> (
match l with
| [] -> pf ppf "%a" pp_expression e
79 changes: 51 additions & 28 deletions src/frontend/SignatureMismatch.ml
Original file line number Diff line number Diff line change
@@ -110,37 +110,50 @@ let rec compare_errors e1 e2 =
| SuffixMismatch _, _ | _, InputMismatch _ -> -1
| InputMismatch _, _ | _, SuffixMismatch _ -> 1 ) )

type promotions = None | RealPromotion | ComplexPromotion

let rec check_same_type depth t1 t2 =
let wrap_func = Option.map ~f:(fun e -> TypeMismatch (t1, t2, Some e)) in
let wrap_func = Result.map_error ~f:(fun e -> TypeMismatch (t1, t2, Some e)) in
match (t1, t2) with
| t1, t2 when t1 = t2 -> None
| UnsizedType.(UReal, UInt) when depth < 1 -> None
| t1, t2 when t1 = t2 -> Ok None
| UnsizedType.(UReal, UInt) when depth < 1 -> Ok RealPromotion
| UnsizedType.(UComplex, UInt) when depth < 1 -> Ok ComplexPromotion
| UnsizedType.(UComplex, UReal) when depth < 1 -> Ok ComplexPromotion
(* Arrays: Try to recursively promote, but make sure the error is for these types,
not the recursive call *)
| UArray nt1, UArray nt2 ->
check_same_type depth nt1 nt2
|> Result.map_error ~f:(function
| TypeMismatch _ -> TypeMismatch (t1, t2, None)
| e -> e )
| UFun (_, _, s1, _), UFun (_, _, s2, _)
when Fun_kind.without_propto s1 <> Fun_kind.without_propto s2 ->
Some
Error
(SuffixMismatch (Fun_kind.without_propto s1, Fun_kind.without_propto s2))
|> wrap_func
| UFun (_, rt1, _, _), UFun (_, rt2, _, _) when rt1 <> rt2 ->
Some (ReturnTypeMismatch (rt1, rt2)) |> wrap_func
| UFun (l1, _, _, _), UFun (l2, _, _, _) ->
check_compatible_arguments (depth + 1) l2 l1
|> Option.map ~f:(fun e -> InputMismatch e)
|> wrap_func
| t1, t2 -> Some (TypeMismatch (t1, t2, None))
Error (ReturnTypeMismatch (rt1, rt2)) |> wrap_func
| UFun (l1, _, _, _), UFun (l2, _, _, _) -> (
match check_compatible_arguments (depth + 1) l2 l1 with
| Ok _ -> Ok None
| Error e -> Error (InputMismatch e) |> wrap_func )
| t1, t2 -> Error (TypeMismatch (t1, t2, None))

and check_compatible_arguments depth args1 args2 =
match List.zip args1 args2 with
and check_compatible_arguments depth typs args2 :
(promotions list, function_mismatch) result =
match List.zip typs args2 with
| List.Or_unequal_lengths.Unequal_lengths ->
Some (ArgNumMismatch (List.length args1, List.length args2))
Error (ArgNumMismatch (List.length typs, List.length args2))
| Ok l ->
List.find_mapi l ~f:(fun i ((ad1, ut1), (ad2, ut2)) ->
List.mapi l ~f:(fun i ((ad1, ut1), (ad2, ut2)) ->
match check_same_type depth ut1 ut2 with
| Some e -> Some (ArgError (i + 1, e))
| None ->
if ad1 = ad2 then None
| Error e -> Error (ArgError (i + 1, e))
| Ok p ->
if ad1 = ad2 then Ok p
else if depth < 2 && UnsizedType.autodifftype_can_convert ad1 ad2
then None
else Some (ArgError (i + 1, DataOnlyError)) )
then Ok p
else Error (ArgError (i + 1, DataOnlyError)) )
|> Result.all

let check_compatible_arguments_mod_conv = check_compatible_arguments 0
let max_n_errors = 5
@@ -153,6 +166,16 @@ let extract_function_types f =
Some (return, args, (fun x -> UserDefined x), mem)
| _ -> None

let promote es promotions =
List.map2_exn es promotions ~f:(fun (exp : Ast.typed_expression) prom ->
let emeta = exp.emeta in
match prom with
| RealPromotion when emeta.type_ <> UReal ->
Ast.{expr= Ast.Promotion (exp, UReal); emeta}
| ComplexPromotion when emeta.type_ <> UComplex ->
{expr= Promotion (exp, UComplex); emeta}
| _ -> exp )

let returntype env name args =
(* NB: Variadic arguments are special-cased in the typechecker and not handled here *)
let name = Utils.stdlib_distribution_name name in
@@ -164,8 +187,8 @@ let returntype env name args =
|> List.fold_until ~init:[]
~f:(fun errors (rt, tys, funkind_constructor, _) ->
match check_compatible_arguments 0 tys args with
| None -> Stop (Ok (rt, funkind_constructor))
| Some e -> Continue (((rt, tys), e) :: errors) )
| Ok p -> Stop (Ok (rt, funkind_constructor, p))
| Error e -> Continue (((rt, tys), e) :: errors) )
~finish:(fun errors ->
let errors =
List.sort errors ~compare:(fun (_, e1) (_, e2) ->
@@ -180,7 +203,7 @@ let check_variadic_args allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys
in
let minimal_args =
(UnsizedType.AutoDiffable, minimal_func_type) :: mandatory_arg_tys in
let wrap_err x = Some (minimal_args, ArgError (1, x)) in
let wrap_err x = Error (minimal_args, ArgError (1, x)) in
match args with
| ( _
, ( UnsizedType.UFun (fun_args, ReturnType return_type, suffix, _) as
@@ -193,22 +216,22 @@ let check_variadic_args allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys
let suffix = Fun_kind.without_propto suffix in
if suffix = FnPlain || (allow_lpdf && suffix = FnLpdf ()) then
match check_compatible_arguments 1 mandatory mandatory_fun_arg_tys with
| Some x -> wrap_func_error (InputMismatch x)
| None -> (
| Error x -> wrap_func_error (InputMismatch x)
| Ok _ -> (
match check_same_type 1 return_type fun_return with
| Some _ ->
| Error _ ->
wrap_func_error
(ReturnTypeMismatch
(ReturnType fun_return, ReturnType return_type) )
| None ->
| Ok _ ->
let expected_args =
((UnsizedType.AutoDiffable, func_type) :: mandatory_arg_tys)
@ variadic_arg_tys in
check_compatible_arguments 0 expected_args args
|> Option.map ~f:(fun x -> (expected_args, x)) )
|> Result.map_error ~f:(fun x -> (expected_args, x)) )
else wrap_func_error (SuffixMismatch (FnPlain, suffix))
| (_, x) :: _ -> TypeMismatch (minimal_func_type, x, None) |> wrap_err
| [] -> Some ([], ArgNumMismatch (List.length mandatory_arg_tys, 0))
| [] -> Error ([], ArgNumMismatch (List.length mandatory_arg_tys, 0))

let pp_signature_mismatch ppf (name, arg_tys, (sigs, omitted)) =
let open Fmt in
20 changes: 16 additions & 4 deletions src/frontend/SignatureMismatch.mli
Original file line number Diff line number Diff line change
@@ -19,16 +19,27 @@ type signature_error =
(UnsizedType.returntype * (UnsizedType.autodifftype * UnsizedType.t) list)
* function_mismatch

(** Indicate a promotion by the resulting type *)
type promotions = private None | RealPromotion | ComplexPromotion

val check_compatible_arguments_mod_conv :
(UnsizedType.autodifftype * UnsizedType.t) list
-> (UnsizedType.autodifftype * UnsizedType.t) list
-> function_mismatch option
-> (promotions list, function_mismatch) result

val promote :
Ast.typed_expression list -> promotions list -> Ast.typed_expression list
(** Given a list of expressions (arguments) and a list of [promotions],
return a list of expressions which include the
[Promotion] expression as appropiate *)

val returntype :
Environment.t
-> string
-> (UnsizedType.autodifftype * UnsizedType.t) list
-> ( UnsizedType.returntype * (bool Middle.Fun_kind.suffix -> Ast.fun_kind)
-> ( UnsizedType.returntype
* (bool Middle.Fun_kind.suffix -> Ast.fun_kind)
* promotions list
, signature_error list * bool )
result

@@ -38,8 +49,9 @@ val check_variadic_args :
-> (UnsizedType.autodifftype * UnsizedType.t) list
-> UnsizedType.t
-> (UnsizedType.autodifftype * UnsizedType.t) list
-> ((UnsizedType.autodifftype * UnsizedType.t) list * function_mismatch)
option
-> ( promotions list
, (UnsizedType.autodifftype * UnsizedType.t) list * function_mismatch )
result

val pp_signature_mismatch :
Format.formatter
40 changes: 27 additions & 13 deletions src/frontend/Typechecker.ml
Original file line number Diff line number Diff line change
@@ -425,14 +425,16 @@ let check_fn ~is_cond_dist loc tenv id es =
|> error
| _ (* a function *) -> (
match SignatureMismatch.returntype tenv id.name (get_arg_types es) with
| Ok (Void, _) ->
| Ok (Void, _, _) ->
Semantic_error.returning_fn_expected_nonreturning_found loc id.name
|> error
| Ok (ReturnType ut, fnk) ->
| Ok (ReturnType ut, fnk, promotions) ->
mk_typed_expression
~expr:
(mk_fun_app ~is_cond_dist
(fnk (Fun_kind.suffix_from_name id.name), id, es) )
( fnk (Fun_kind.suffix_from_name id.name)
, id
, SignatureMismatch.promote es promotions ) )
~ad_level:(expr_ad_lub es) ~type_:ut ~loc
| Error x ->
es
@@ -458,11 +460,13 @@ let check_reduce_sum ~is_cond_dist loc id es =
SignatureMismatch.check_variadic_args true mandatory_args
mandatory_fun_args UReal (get_arg_types es)
with
| None ->
| Ok promotions ->
mk_typed_expression
~expr:(mk_fun_app ~is_cond_dist (StanLib FnPlain, id, es))
~expr:
(mk_fun_app ~is_cond_dist
(StanLib FnPlain, id, SignatureMismatch.promote es promotions) )
~ad_level:(expr_ad_lub es) ~type_:UnsizedType.UReal ~loc
| Some (expected_args, err) ->
| Error (expected_args, err) ->
Semantic_error.illtyped_reduce_sum loc id.name
(List.map ~f:type_of_expr_typed es)
expected_args err
@@ -477,7 +481,7 @@ let check_reduce_sum ~is_cond_dist loc id es =
let expected_args, err =
SignatureMismatch.check_variadic_args true mandatory_args
mandatory_fun_args UReal (get_arg_types es)
|> Option.value_exn in
|> Result.error |> Option.value_exn in
Semantic_error.illtyped_reduce_sum_generic loc id.name
(List.map ~f:type_of_expr_typed es)
expected_args err
@@ -498,12 +502,14 @@ let check_variadic_ode ~is_cond_dist loc id es =
Stan_math_signatures.variadic_ode_mandatory_fun_args
Stan_math_signatures.variadic_ode_fun_return_type (get_arg_types es)
with
| None ->
| Ok promotions ->
mk_typed_expression
~expr:(mk_fun_app ~is_cond_dist (StanLib FnPlain, id, es))
~expr:
(mk_fun_app ~is_cond_dist
(StanLib FnPlain, id, SignatureMismatch.promote es promotions) )
~ad_level:(expr_ad_lub es)
~type_:Stan_math_signatures.variadic_ode_return_type ~loc
| Some (expected_args, err) ->
| Error (expected_args, err) ->
Semantic_error.illtyped_variadic_ode loc id.name
(List.map ~f:type_of_expr_typed es)
expected_args err
@@ -657,6 +663,10 @@ and check_expression cf tenv ({emeta; expr} : Ast.untyped_expression) :
es |> List.map ~f:ce |> check_funapp loc cf tenv ~is_cond_dist:false id
| CondDistApp ((), id, es) ->
es |> List.map ~f:ce |> check_funapp loc cf tenv ~is_cond_dist:true id
| Promotion (e, _) ->
(* Should never happen: promotions are produced during typechecking *)
Common.FatalError.fatal_error_msg
[%message "Promotion in untyped AST" (e : Ast.untyped_expression)]

and check_expression_of_int_type cf tenv e name =
let te = check_expression cf tenv e in
@@ -699,11 +709,15 @@ let check_nrfn loc tenv id es =
|> error
| _ (* a function *) -> (
match SignatureMismatch.returntype tenv id.name (get_arg_types es) with
| Ok (Void, fnk) ->
| Ok (Void, fnk, promotions) ->
mk_typed_statement
~stmt:(NRFunApp (fnk (Fun_kind.suffix_from_name id.name), id, es))
~stmt:
(NRFunApp
( fnk (Fun_kind.suffix_from_name id.name)
, id
, SignatureMismatch.promote es promotions ) )
~return_type:NoReturnType ~loc
| Ok (ReturnType _, _) ->
| Ok (ReturnType _, _, _) ->
Semantic_error.nonreturning_fn_expected_returning_found loc id.name
|> error
| Error x ->
5 changes: 5 additions & 0 deletions src/middle/Expr.ml
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@ module Fixed = struct
| EAnd of 'a * 'a
| EOr of 'a * 'a
| Indexed of 'a * 'a Index.t list
| Promotion of 'a * UnsizedType.t
[@@deriving sexp, hash, map, compare, fold]

let pp pp_e ppf = function
@@ -42,6 +43,8 @@ module Fixed = struct
indices
| EAnd (l, r) -> Fmt.pf ppf "%a && %a" pp_e l pp_e r
| EOr (l, r) -> Fmt.pf ppf "%a || %a" pp_e l pp_e r
| Promotion (from, ut) ->
Fmt.pf ppf "%a -> %a" pp_e from UnsizedType.pp ut

include Foldable.Make (struct
type nonrec 'a t = 'a t
@@ -146,6 +149,8 @@ module Labelled = struct
associate ~init:(associate ~init:(associate ~init:assocs e3) e2) e1
| Indexed (e, idxs) ->
List.fold idxs ~init:(associate ~init:assocs e) ~f:associate_index
(* Not sure?*)
| Promotion (e1, _) -> associate ~init:assocs e1

and associate_index assocs = function
| All -> assocs
1 change: 1 addition & 0 deletions src/middle/Expr.mli
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@ module Fixed : sig
| EAnd of 'a * 'a
| EOr of 'a * 'a
| Indexed of 'a * 'a Index.t list
| Promotion of 'a * UnsizedType.t
[@@deriving sexp, hash, compare]

include Pattern.S with type 'a t := 'a t
14 changes: 7 additions & 7 deletions src/middle/Stan_math_signatures.ml
Original file line number Diff line number Diff line change
@@ -339,6 +339,7 @@ let fst2 (a, _, _) = a
let thrd (_, _, c) = c

(* -- Querying stan_math_signatures -- *)
(* TODO: Remove/prefer to use SignatureMismatch instead of UnsizedType *)
let stan_math_returntype (name : string) (args : fun_arg list) =
let name = Utils.stdlib_distribution_name name in
let namematches = Hashtbl.find_multi stan_math_signatures name in
@@ -351,13 +352,12 @@ let stan_math_returntype (name : string) (args : fun_arg list) =
| x when is_reduce_sum_fn x -> Some (UnsizedType.ReturnType UReal)
| x when is_variadic_ode_fn x -> Some (UnsizedType.ReturnType (UArray UVector))
| _ ->
if List.length filteredmatches = 0 then None
(* Return the least return type in case there are multiple options (due to implicit UInt-UReal conversion), where UInt<UReal *)
else
Some
(List.hd_exn
(List.sort ~compare:UnsizedType.compare_returntype
(List.map ~f:fst2 filteredmatches) ) )
(* Return the least return type in case there are multiple options
(due to implicit UInt-UReal conversion), where UInt<UReal
*)
List.hd
(List.sort ~compare:UnsizedType.compare_returntype
(List.map ~f:fst2 filteredmatches) )

let is_stan_math_function_name name =
let name = Utils.stdlib_distribution_name name in
5 changes: 5 additions & 0 deletions src/stan_math_backend/Expression_gen.ml
Original file line number Diff line number Diff line change
@@ -497,6 +497,11 @@ and pp_expr ppf Expr.Fixed.({pattern; meta} as e) =
| Lit (Str, s) -> pf ppf "\"%s\"" (Cpp_str.escaped s)
| Lit (Imaginary, s) -> pf ppf "stan::math::to_complex(0, %s)" s
| Lit ((Real | Int), s) -> pf ppf "%s" s
| Promotion (expr, ut) ->
if is_scalar expr && ut = UReal then pp_expr ppf expr
else
pf ppf "stan::math::promote_scalar<%a>(%a)" pp_unsizedtype_local
(meta.adlevel, ut) pp_expr expr
| FunApp
( StanLib (op, _, _)
, [ { meta= {type_= URowVector; _}
84 changes: 0 additions & 84 deletions test/integration/bad/algebra_solver/stanc.expected
Original file line number Diff line number Diff line change
@@ -125,48 +125,6 @@ Available signatures:
data real, data real) => vector
where F2 = (vector, vector, data array[] real, data array[] int) => vector
The 5th argument must be array[] int but got array[] real
(<F2>, vector, vector, data array[] real, data array[] int) => vector
Expected 5 arguments but found 8 arguments.
$ ../../../../../install/default/bin/stanc bad_newton_x_r_type.stan
Semantic error in 'bad_newton_x_r_type.stan', line 31, column 10 to column 69:
-------------------------------------------------
29: transformed parameters {
30: vector[2] y_s_p;
31: y_s_p = algebra_solver_newton(algebra_system, y, theta_p, x_r, x_i);
^
32: }
33:
-------------------------------------------------

Ill-typed arguments supplied to function 'algebra_solver_newton':
(<F1>, vector, vector, array[] int, array[] int)
where F1 = (vector, vector, array[] real, array[] int) => vector
Available signatures:
(<F2>, vector, vector, data array[] real, data array[] int) => vector
where F2 = (vector, vector, data array[] real, data array[] int) => vector
The fourth argument must be array[] real but got array[] int
(<F2>, vector, vector, data array[] real, data array[] int, data real,
data real, data real) => vector
Expected 8 arguments but found 5 arguments.
$ ../../../../../install/default/bin/stanc bad_newton_x_r_type_control.stan
Semantic error in 'bad_newton_x_r_type_control.stan', line 31, column 10 to column 85:
-------------------------------------------------
29: transformed parameters {
30: vector[2] y_s_p;
31: y_s_p = algebra_solver_newton(algebra_system, y, theta_p, x_r, x_i, 0.01, 0.01, 10);
^
32: }
33:
-------------------------------------------------

Ill-typed arguments supplied to function 'algebra_solver_newton':
(<F1>, vector, vector, array[] int, array[] int, real, real, int)
where F1 = (vector, vector, array[] real, array[] int) => vector
Available signatures:
(<F2>, vector, vector, data array[] real, data array[] int, data real,
data real, data real) => vector
where F2 = (vector, vector, data array[] real, data array[] int) => vector
The fourth argument must be array[] real but got array[] int
(<F2>, vector, vector, data array[] real, data array[] int) => vector
Expected 5 arguments but found 8 arguments.
$ ../../../../../install/default/bin/stanc bad_newton_x_r_var_type.stan
@@ -371,48 +329,6 @@ Available signatures:
data real, data real) => vector
where F2 = (vector, vector, data array[] real, data array[] int) => vector
The 5th argument must be array[] int but got array[] real
(<F2>, vector, vector, data array[] real, data array[] int) => vector
Expected 5 arguments but found 8 arguments.
$ ../../../../../install/default/bin/stanc bad_x_r_type.stan
Semantic error in 'bad_x_r_type.stan', line 31, column 10 to column 62:
-------------------------------------------------
29: transformed parameters {
30: vector[2] y_s_p;
31: y_s_p = algebra_solver(algebra_system, y, theta_p, x_r, x_i);
^
32: }
33:
-------------------------------------------------

Ill-typed arguments supplied to function 'algebra_solver':
(<F1>, vector, vector, array[] int, array[] int)
where F1 = (vector, vector, array[] real, array[] int) => vector
Available signatures:
(<F2>, vector, vector, data array[] real, data array[] int) => vector
where F2 = (vector, vector, data array[] real, data array[] int) => vector
The fourth argument must be array[] real but got array[] int
(<F2>, vector, vector, data array[] real, data array[] int, data real,
data real, data real) => vector
Expected 8 arguments but found 5 arguments.
$ ../../../../../install/default/bin/stanc bad_x_r_type_control.stan
Semantic error in 'bad_x_r_type_control.stan', line 31, column 10 to column 78:
-------------------------------------------------
29: transformed parameters {
30: vector[2] y_s_p;
31: y_s_p = algebra_solver(algebra_system, y, theta_p, x_r, x_i, 0.01, 0.01, 10);
^
32: }
33:
-------------------------------------------------

Ill-typed arguments supplied to function 'algebra_solver':
(<F1>, vector, vector, array[] int, array[] int, real, real, int)
where F1 = (vector, vector, array[] real, array[] int) => vector
Available signatures:
(<F2>, vector, vector, data array[] real, data array[] int, data real,
data real, data real) => vector
where F2 = (vector, vector, data array[] real, data array[] int) => vector
The fourth argument must be array[] real but got array[] int
(<F2>, vector, vector, data array[] real, data array[] int) => vector
Expected 5 arguments but found 8 arguments.
$ ../../../../../install/default/bin/stanc bad_x_r_var_type.stan
23 changes: 0 additions & 23 deletions test/integration/bad/stanc.expected
Original file line number Diff line number Diff line change
@@ -1779,29 +1779,6 @@ Semantic error in 'prob-poisson_log-trunc-low.stan', line 11, column 2 to column
-------------------------------------------------

Truncation is only defined if distribution has _lcdf and _lccdf functions implemented with appropriate signature.
$ ../../../../install/default/bin/stanc propto-bad1.stan
Semantic error in 'propto-bad1.stan', line 2, column 14 to column 37:
-------------------------------------------------
1: model {
2: target += normal_lupdf(1| {0}, 1);
^
3: }
-------------------------------------------------

Ill-typed arguments supplied to function 'normal_lupdf':
(int, array[] int, int)
Available signatures:
(real, real, real) => real
The second argument must be real but got array[] int
(real, real, array[] real) => real
The second argument must be real but got array[] int
(real, real, vector) => real
The second argument must be real but got array[] int
(real, real, row_vector) => real
The second argument must be real but got array[] int
(real, vector, real) => real
The second argument must be vector but got array[] int
(Additional signatures omitted)
$ ../../../../install/default/bin/stanc real-pdf.stan
Semantic error in 'real-pdf.stan', line 2, column 2 to line 4, column 3:
-------------------------------------------------
129 changes: 64 additions & 65 deletions test/integration/good/code-gen/complex_scalar.stan

Large diffs are not rendered by default.

420 changes: 420 additions & 0 deletions test/integration/good/code-gen/cpp.expected

Large diffs are not rendered by default.

72 changes: 60 additions & 12 deletions test/integration/good/code-gen/mir.expected
Original file line number Diff line number Diff line change
@@ -251,9 +251,17 @@
(((pattern (Var u))
(meta
((type_ UReal) (loc <opaque>) (adlevel AutoDiffable))))
((pattern (Lit Int 0))
((pattern
(Promotion
((pattern (Lit Int 0))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Int 1))
((pattern
(Promotion
((pattern (Lit Int 1))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta ((type_ UReal) (loc <opaque>) (adlevel AutoDiffable))))))
(meta <opaque>))
@@ -4827,7 +4835,11 @@
(FunApp (StanLib diag_matrix FnPlain AoS)
(((pattern
(FunApp (StanLib rep_vector FnPlain AoS)
(((pattern (Lit Int 1))
(((pattern
(Promotion
((pattern (Lit Int 1))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
((pattern
(FunApp (StanLib rows FnPlain AoS)
@@ -4843,7 +4855,11 @@
(FunApp (StanLib diag_matrix FnPlain AoS)
(((pattern
(FunApp (StanLib rep_vector FnPlain AoS)
(((pattern (Lit Int 1))
(((pattern
(Promotion
((pattern (Lit Int 1))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
((pattern
(FunApp (StanLib rows FnPlain AoS)
@@ -8735,7 +8751,11 @@
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Real 0.01))
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Int 10))
((pattern
(Promotion
((pattern (Lit Int 10))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta ((type_ UVector) (loc <opaque>) (adlevel AutoDiffable))))))
(meta <opaque>))
@@ -8763,7 +8783,11 @@
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Real 0.01))
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Int 10))
((pattern
(Promotion
((pattern (Lit Int 10))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta ((type_ UVector) (loc <opaque>) (adlevel AutoDiffable))))))
(meta <opaque>))
@@ -8813,7 +8837,11 @@
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Real 0.01))
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Int 10))
((pattern
(Promotion
((pattern (Lit Int 10))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta ((type_ UVector) (loc <opaque>) (adlevel AutoDiffable))))))
(meta <opaque>))
@@ -8863,7 +8891,11 @@
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Real 0.01))
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Int 10))
((pattern
(Promotion
((pattern (Lit Int 10))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta ((type_ UVector) (loc <opaque>) (adlevel AutoDiffable))))))
(meta <opaque>))
@@ -10969,7 +11001,11 @@
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Real 0.01))
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Int 10))
((pattern
(Promotion
((pattern (Lit Int 10))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta ((type_ UVector) (loc <opaque>) (adlevel AutoDiffable))))))
(meta <opaque>))
@@ -10997,7 +11033,11 @@
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Real 0.01))
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Int 10))
((pattern
(Promotion
((pattern (Lit Int 10))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta ((type_ UVector) (loc <opaque>) (adlevel AutoDiffable))))))
(meta <opaque>))
@@ -11047,7 +11087,11 @@
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Real 0.01))
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Int 10))
((pattern
(Promotion
((pattern (Lit Int 10))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta ((type_ UVector) (loc <opaque>) (adlevel AutoDiffable))))))
(meta <opaque>))
@@ -11097,7 +11141,11 @@
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Real 0.01))
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Int 10))
((pattern
(Promotion
((pattern (Lit Int 10))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta ((type_ UVector) (loc <opaque>) (adlevel AutoDiffable))))))
(meta <opaque>))
16 changes: 16 additions & 0 deletions test/integration/good/code-gen/promotion.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
functions {
complex ident (complex x){
return x;
}

real foo(array[] real zs){
return zs[0];
}
}

generated quantities {
real z = 1;
complex zi = ident(z);
array[4] int zs = {1,2,3,4};
real x = foo(zs);
}
72 changes: 60 additions & 12 deletions test/integration/good/code-gen/transformed_mir.expected
Original file line number Diff line number Diff line change
@@ -254,9 +254,17 @@
(((pattern (Var u))
(meta
((type_ UReal) (loc <opaque>) (adlevel AutoDiffable))))
((pattern (Lit Int 0))
((pattern
(Promotion
((pattern (Lit Int 0))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Int 1))
((pattern
(Promotion
((pattern (Lit Int 1))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta ((type_ UReal) (loc <opaque>) (adlevel AutoDiffable))))))
(meta <opaque>))
@@ -8234,7 +8242,11 @@
(FunApp (StanLib diag_matrix FnPlain AoS)
(((pattern
(FunApp (StanLib rep_vector FnPlain AoS)
(((pattern (Lit Int 1))
(((pattern
(Promotion
((pattern (Lit Int 1))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
((pattern
(FunApp (StanLib rows FnPlain AoS)
@@ -8250,7 +8262,11 @@
(FunApp (StanLib diag_matrix FnPlain AoS)
(((pattern
(FunApp (StanLib rep_vector FnPlain AoS)
(((pattern (Lit Int 1))
(((pattern
(Promotion
((pattern (Lit Int 1))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
((pattern
(FunApp (StanLib rows FnPlain AoS)
@@ -12542,7 +12558,11 @@
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Real 0.01))
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Int 10))
((pattern
(Promotion
((pattern (Lit Int 10))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta ((type_ UVector) (loc <opaque>) (adlevel AutoDiffable))))))
(meta <opaque>))
@@ -12570,7 +12590,11 @@
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Real 0.01))
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Int 10))
((pattern
(Promotion
((pattern (Lit Int 10))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta ((type_ UVector) (loc <opaque>) (adlevel AutoDiffable))))))
(meta <opaque>))
@@ -12620,7 +12644,11 @@
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Real 0.01))
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Int 10))
((pattern
(Promotion
((pattern (Lit Int 10))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta ((type_ UVector) (loc <opaque>) (adlevel AutoDiffable))))))
(meta <opaque>))
@@ -12670,7 +12698,11 @@
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Real 0.01))
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Int 10))
((pattern
(Promotion
((pattern (Lit Int 10))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta ((type_ UVector) (loc <opaque>) (adlevel AutoDiffable))))))
(meta <opaque>))
@@ -16218,7 +16250,11 @@
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Real 0.01))
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Int 10))
((pattern
(Promotion
((pattern (Lit Int 10))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta ((type_ UVector) (loc <opaque>) (adlevel AutoDiffable))))))
(meta <opaque>))
@@ -16246,7 +16282,11 @@
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Real 0.01))
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Int 10))
((pattern
(Promotion
((pattern (Lit Int 10))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta ((type_ UVector) (loc <opaque>) (adlevel AutoDiffable))))))
(meta <opaque>))
@@ -16296,7 +16336,11 @@
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Real 0.01))
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Int 10))
((pattern
(Promotion
((pattern (Lit Int 10))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta ((type_ UVector) (loc <opaque>) (adlevel AutoDiffable))))))
(meta <opaque>))
@@ -16346,7 +16390,11 @@
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Real 0.01))
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
((pattern (Lit Int 10))
((pattern
(Promotion
((pattern (Lit Int 10))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))
UReal))
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly)))))))
(meta ((type_ UVector) (loc <opaque>) (adlevel AutoDiffable))))))
(meta <opaque>))
320 changes: 167 additions & 153 deletions test/integration/good/compiler-optimizations/cpp.expected

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -91,11 +91,10 @@ generated quantities {
gq_r = foo5(gq_complex_array);
gq_complex_array = foo6(gq_r);
gq_complex_array = foo7(gq_complex_array);

// test imaginary literal
complex zi = 1 + 3.14i;
zi = zi * 0i;
complex yi = to_complex(0, 1.1) + to_complex(0.0, 2.2) + to_complex();
real x = get_real(3i - 40e-3i);
}

102 changes: 102 additions & 0 deletions test/integration/good/lang/pretty.expected
Original file line number Diff line number Diff line change
@@ -1,3 +1,105 @@
$ ../../../../../install/default/bin/stanc --auto-format good_complex.stan
functions {
complex foo() {
return to_complex();
}
real foo1(complex z) {
return 1.0;
}
complex foo2(real r) {
return to_complex(r);
}
complex foo3(complex z) {
return z;
}
array[] complex foo4() {
return {to_complex(), to_complex()};
}
real foo5(array[] complex z) {
return 1.0;
}
array[] complex foo6(real r) {
return {to_complex(r), to_complex(r, r)};
}
array[] complex foo7(array[] complex z) {
return z;
}
}
data {
int<lower=0> N;
complex d_complex;
array[N] complex d_complex_array;
array[N, 2] complex d_complex_array_2d;
array[N, 2, 3] complex d_complex_array_3d;
}
transformed data {
complex td_complex = d_complex;
td_complex = 1;
td_complex = to_complex(1, 2);
array[N] complex td_complex_array = d_complex_array;
array[2, 3] complex td_complex_array_2d;
array[2, 3, 4] complex td_complex_array_3d;
for (td_i in 1 : 2) {
for (td_j in 1 : 3) {
for (td_k in 1 : 4) {
td_complex_array_3d[td_i, td_j, td_k] = to_complex(1, 2.2);
}
}
}
real td_r = get_real(td_complex);
real td_i = get_imag(td_complex);
}
parameters {
complex p_complex;
array[N] complex p_complex_array;
array[N, 2] complex p_complex_array_2d;
array[N, 2, 3] complex p_complex_array_3d;
}
transformed parameters {
complex tp_complex = p_complex;
tp_complex = 1;
tp_complex = to_complex(1, 2);
array[N] complex tp_complex_array = p_complex_array;
array[2, 3] complex tp_complex_array_2d;
array[2, 3, 4] complex tp_complex_array_3d;
for (tp_i in 1 : 2) {
for (tp_j in 1 : 3) {
for (tp_k in 1 : 4) {
tp_complex_array_3d[tp_i, tp_j, tp_k] = to_complex(1, 2.2);
}
}
}
real tp_r = get_real(tp_complex);
real tp_i = get_imag(tp_complex);
}
model {
abs(p_complex) ~ normal(0, 1);
}
generated quantities {
complex gq_complex;
array[2] complex gq_complex_array;
array[2, 3] complex gq_complex_array_2d = {{1, 2, 3},
{to_complex(), to_complex(
1.1), to_complex(1, 2.1)}};
array[2, 3, 4] complex gq_complex_array_3d;
gq_complex = 3;
gq_complex_array[1] = to_complex(5.1, 6);
gq_complex = foo();
real gq_r = foo1(gq_complex);
gq_complex = foo2(gq_r);
gq_complex = foo3(gq_complex);
gq_complex_array = foo4();
gq_r = foo5(gq_complex_array);
gq_complex_array = foo6(gq_r);
gq_complex_array = foo7(gq_complex_array);

// test imaginary literal
complex zi = 1 + 3.14i;
zi = zi * 0i;
complex yi = to_complex(0, 1.1) + to_complex(0.0, 2.2) + to_complex();
real x = get_real(3i - 40e-3i);
}

$ ../../../../../install/default/bin/stanc --auto-format good_const.stan
transformed data {
real y;
102 changes: 0 additions & 102 deletions test/integration/good/parser-generator/pretty.expected
Original file line number Diff line number Diff line change
@@ -60,108 +60,6 @@ transformed parameters {
cholesky_factor_cov[3] cfcov_33;
}

$ ../../../../../install/default/bin/stanc --auto-format complex_scalar.stan
functions {
complex foo() {
return to_complex();
}
real foo1(complex z) {
return 1.0;
}
complex foo2(real r) {
return to_complex(r);
}
complex foo3(complex z) {
return z;
}
array[] complex foo4() {
return {to_complex(), to_complex()};
}
real foo5(array[] complex z) {
return 1.0;
}
array[] complex foo6(real r) {
return {to_complex(r), to_complex(r, r)};
}
array[] complex foo7(array[] complex z) {
return z;
}
}
data {
int<lower=0> N;
complex d_complex;
array[N] complex d_complex_array;
array[N, 2] complex d_complex_array_2d;
array[N, 2, 3] complex d_complex_array_3d;
}
transformed data {
complex td_complex = d_complex;
td_complex = 1;
td_complex = to_complex(1, 2);
array[N] complex td_complex_array = d_complex_array;
array[2, 3] complex td_complex_array_2d;
array[2, 3, 4] complex td_complex_array_3d;
for (td_i in 1 : 2) {
for (td_j in 1 : 3) {
for (td_k in 1 : 4) {
td_complex_array_3d[td_i, td_j, td_k] = to_complex(1, 2.2);
}
}
}
real td_r = get_real(td_complex);
real td_i = get_imag(td_complex);
}
parameters {
complex p_complex;
array[N] complex p_complex_array;
array[N, 2] complex p_complex_array_2d;
array[N, 2, 3] complex p_complex_array_3d;
}
transformed parameters {
complex tp_complex = p_complex;
tp_complex = 1;
tp_complex = to_complex(1, 2);
array[N] complex tp_complex_array = p_complex_array;
array[2, 3] complex tp_complex_array_2d;
array[2, 3, 4] complex tp_complex_array_3d;
for (tp_i in 1 : 2) {
for (tp_j in 1 : 3) {
for (tp_k in 1 : 4) {
tp_complex_array_3d[tp_i, tp_j, tp_k] = to_complex(1, 2.2);
}
}
}
real tp_r = get_real(tp_complex);
real tp_i = get_imag(tp_complex);
}
model {
abs(p_complex) ~ normal(0, 1);
}
generated quantities {
complex gq_complex;
array[2] complex gq_complex_array;
array[2, 3] complex gq_complex_array_2d = {{1, 2, 3},
{to_complex(), to_complex(
1.1), to_complex(1, 2.1)}};
array[2, 3, 4] complex gq_complex_array_3d;
gq_complex = 3;
gq_complex_array[1] = to_complex(5.1, 6);
gq_complex = foo();
real gq_r = foo1(gq_complex);
gq_complex = foo2(gq_r);
gq_complex = foo3(gq_complex);
gq_complex_array = foo4();
gq_r = foo5(gq_complex_array);
gq_complex_array = foo6(gq_r);
gq_complex_array = foo7(gq_complex_array);

// test imaginary literal
complex zi = 1 + 3.14i;
zi = zi * 0i;
complex yi = to_complex(0, 1.1) + to_complex(0.0, 2.2) + to_complex();
real x = get_real(3i - 40e-3i);
}

$ ../../../../../install/default/bin/stanc --auto-format data_qualifier_scope_xformed_data.stan
functions {
vector target_(vector y, vector theta, array[] real x_r, array[] int x_i) {
10 changes: 10 additions & 0 deletions test/integration/good/promotion/array_promotion.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
functions {
void printer(array[] real x){
print(x);
}
}

model {
array[3] int d = {1,2,3};
printer(d);
}
32 changes: 32 additions & 0 deletions test/integration/good/promotion/complex_functions.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
functions {
void promote_complex_array(array[] complex zs){
print(zs[0]);
}
}

generated quantities {
real x = norm(1);
x = norm(1.5);
x = norm(3i);

real y = abs(4+3i);
y = arg(4+1i);
y = arg(2.5);
y = arg(1);

complex z;
z = conj(4.1+7i);
z = conj(4.1);
z = conj(0);

z = proj(4.1+7i);
z = proj(4.1);
z = proj(0);

z = polar(1.5,0.5);
z = polar(2,3);

array[3] int xs = {1,2,3};
promote_complex_array(xs);

}
1 change: 1 addition & 0 deletions test/integration/good/promotion/dune
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
(include ../dune)
170 changes: 170 additions & 0 deletions test/integration/good/promotion/pretty.expected
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
$ ../../../../../install/default/bin/stanc --auto-format array_promotion.stan
functions {
void printer(array[] real x) {
print(x);
}
}
model {
array[3] int d = {1, 2, 3};
printer(d);
}

$ ../../../../../install/default/bin/stanc --auto-format bad_newton_x_r_type.stan
functions {
vector algebra_system(vector y, vector theta, array[] real x_r,
array[] int x_i) {
vector[2] f_y;
f_y[1] = y[1] - theta[1];
f_y[2] = y[2] - theta[2];
return f_y;
}
}
data {

}
transformed data {
vector[2] y;
array[0] int x_r;
array[0] int x_i;
}
parameters {
vector[2] theta_p;
real dummy_parameter;
}
transformed parameters {
vector[2] y_s_p;
y_s_p = algebra_solver_newton(algebra_system, y, theta_p, x_r, x_i);
}
model {
dummy_parameter ~ normal(0, 1);
}

$ ../../../../../install/default/bin/stanc --auto-format bad_newton_x_r_type_control.stan
functions {
vector algebra_system(vector y, vector theta, array[] real x_r,
array[] int x_i) {
vector[2] f_y;
f_y[1] = y[1] - theta[1];
f_y[2] = y[2] - theta[2];
return f_y;
}
}
data {

}
transformed data {
vector[2] y;
array[0] int x_r;
array[0] int x_i;
}
parameters {
vector[2] theta_p;
real dummy_parameter;
}
transformed parameters {
vector[2] y_s_p;
y_s_p = algebra_solver_newton(algebra_system, y, theta_p, x_r, x_i, 0.01,
0.01, 10);
}
model {
dummy_parameter ~ normal(0, 1);
}

$ ../../../../../install/default/bin/stanc --auto-format bad_x_r_type.stan
functions {
vector algebra_system(vector y, vector theta, array[] real x_r,
array[] int x_i) {
vector[2] f_y;
f_y[1] = y[1] - theta[1];
f_y[2] = y[2] - theta[2];
return f_y;
}
}
data {

}
transformed data {
vector[2] y;
array[0] int x_r;
array[0] int x_i;
}
parameters {
vector[2] theta_p;
real dummy_parameter;
}
transformed parameters {
vector[2] y_s_p;
y_s_p = algebra_solver(algebra_system, y, theta_p, x_r, x_i);
}
model {
dummy_parameter ~ normal(0, 1);
}

$ ../../../../../install/default/bin/stanc --auto-format bad_x_r_type_control.stan
functions {
vector algebra_system(vector y, vector theta, array[] real x_r,
array[] int x_i) {
vector[2] f_y;
f_y[1] = y[1] - theta[1];
f_y[2] = y[2] - theta[2];
return f_y;
}
}
data {

}
transformed data {
vector[2] y;
array[0] int x_r;
array[0] int x_i;
}
parameters {
vector[2] theta_p;
real dummy_parameter;
}
transformed parameters {
vector[2] y_s_p;
y_s_p = algebra_solver(algebra_system, y, theta_p, x_r, x_i, 0.01, 0.01,
10);
}
model {
dummy_parameter ~ normal(0, 1);
}

$ ../../../../../install/default/bin/stanc --auto-format complex_functions.stan
functions {
void promote_complex_array(array[] complex zs) {
print(zs[0]);
}
}
generated quantities {
real x = norm(1);
x = norm(1.5);
x = norm(3i);

real y = abs(4 + 3i);
y = arg(4 + 1i);
y = arg(2.5);
y = arg(1);

complex z;
z = conj(4.1 + 7i);
z = conj(4.1);
z = conj(0);

z = proj(4.1 + 7i);
z = proj(4.1);
z = proj(0);

z = polar(1.5, 0.5);
z = polar(2, 3);

array[3] int xs = {1, 2, 3};
promote_complex_array(xs);
}

$ ../../../../../install/default/bin/stanc --auto-format propto-bad1.stan
model {
target += normal_lupdf(1 | {0}, 1);
}

File renamed without changes.
4 changes: 2 additions & 2 deletions test/unit/Optimize.ml
Original file line number Diff line number Diff line change
@@ -1964,7 +1964,7 @@ let%expect_test "partial evaluation" =
int i;
FnPrint__(3);
FnPrint__((i + 3));
FnPrint__(log1m(i));
FnPrint__(log((1 - i) -> real));
}
}
}
@@ -2360,7 +2360,7 @@ model {
target += sqrt(34.);
target += sqrt(34.);
target += variance(x_vector);
target += sqrt2();
target += sqrt(2 -> real);
target += squared_distance(x_vector, y_vector);
target += trace(x_matrix);
target += trace_gen_quad_form(x_matrix, z_matrix, y_matrix);