Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
66bf994
Convert reads to constrained reads; remove constraints
rybern Mar 12, 2021
8623456
Handling constraint argument codegen correctly (with hack)
rybern Mar 12, 2021
6ce419d
remove old comments
rybern Mar 12, 2021
9f6836b
merge master
rybern Mar 12, 2021
d045c45
Change in__ from reader to deserializer; remove jacobian from read() …
rybern Mar 12, 2021
03e789e
Don't include lp__ in unconstrainted reads
rybern Mar 12, 2021
ce092c3
merge master
rybern Mar 17, 2021
c28d4ab
update to deserializer interface
rybern Mar 17, 2021
7a601ac
change reader to deserializer in printed C++ code
SteveBronder Mar 18, 2021
0e0fecb
use template keyword when calling read_constrain functions, add a con…
SteveBronder Mar 18, 2021
3eaab9e
add little func for getting the weird constrain dimensions for matrices
SteveBronder Mar 18, 2021
1716dff
Fix bug in pedantic mode brought up by constraint changes
rybern Mar 18, 2021
5191df7
formatting
rybern Mar 18, 2021
fae0929
add template keyword to read for deserializer
SteveBronder Mar 19, 2021
4c9fdaf
add template keyword to read for deserializer
SteveBronder Mar 19, 2021
9d81520
fix typo for cholesky_factor_cov
SteveBronder Mar 19, 2021
d4895b1
remove the need for assigns
rok-cesnovar Mar 19, 2021
7b9cf1c
make format
SteveBronder Mar 19, 2021
4a846cf
turn down the number of parallel jobs for compilation
SteveBronder Mar 20, 2021
c76af48
Merge remote-tracking branch 'upstream/master' into HEAD
SteveBronder Mar 21, 2021
854c78e
update to master
SteveBronder Mar 21, 2021
e66e7db
Merge remote-tracking branch 'upstream/master' into HEAD
SteveBronder Mar 29, 2021
aa5c9bc
Small code cleanup
Mar 29, 2021
75eba6a
format
rybern Mar 29, 2021
ef05433
dune promote
SteveBronder Mar 29, 2021
2a531d1
First pass at function variant types
rybern Mar 29, 2021
53115a7
update Identity for constraints to be blank to match the deserializer…
SteveBronder Mar 30, 2021
b4ad3ae
Function variant types might be working
rybern Mar 30, 2021
4696379
Merge branch 'constraint-refactor-2' of github.com:rybern/stanc3 into…
rybern Mar 30, 2021
5591542
promote tests
rybern Mar 30, 2021
87ba3f1
Improved FnReadParam hack
rybern Mar 30, 2021
9d511d5
Use option type for FnReadParam transformation
rybern Mar 30, 2021
af670fa
Unify NRFunApp with FunApp monotone free vars
seantalts Apr 2, 2021
171edb3
Make Fun_kind.t non-parametric
seantalts Apr 2, 2021
21154d4
tiny reformat
seantalts Apr 2, 2021
38ba10b
Merge branch 'master' of github.com:stan-dev/stanc3 into constraint-r…
seantalts Apr 3, 2021
b8e666e
Promote transformed MIR expect tests
seantalts Apr 3, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import org.stan.Utils

def utils = new org.stan.Utils()
def skipExpressionTests = false

/* Functions that runs a sh command and returns the stdout */
def runShell(String command){
def output = sh (returnStdout: true, script: "${command}").trim()
Expand Down
8 changes: 4 additions & 4 deletions src/analysis_and_optimization/Factor_graph.ml
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ let extract_factors_statement stmt =
match stmt with
| Stmt.Fixed.Pattern.TargetPE e ->
List.map (summation_terms e) ~f:(fun x -> TargetTerm x)
| NRFunApp (_, f, _) when Internal_fun.of_string_opt f = Some FnReject ->
[Reject]
| NRFunApp (_, s, args) when String.suffix s 3 = "_lp" ->
| NRFunApp (CompilerInternal FnReject, _) -> [Reject]
| NRFunApp ((UserDefined s | StanLib s), args) when String.suffix s 3 = "_lp"
->
[LPFunction (s, args)]
| Assignment (_, _)
|NRFunApp (_, _, _)
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip
|IfElse (_, _, _)
|While (_, _)
Expand Down
14 changes: 7 additions & 7 deletions src/analysis_and_optimization/Mir_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ let rec num_expr_value (v : Expr.Typed.t) : (float * string) option =
| {pattern= Fixed.Pattern.Lit (Real, str); _}
|{pattern= Fixed.Pattern.Lit (Int, str); _} ->
Some (float_of_string str, str)
| {pattern= Fixed.Pattern.FunApp (StanLib, "PMinus__", [v]); _} -> (
| {pattern= Fixed.Pattern.FunApp (StanLib "PMinus__", [v]); _} -> (
match num_expr_value v with
| Some (v, s) -> Some (-.v, "-" ^ s)
| None -> None )
Expand Down Expand Up @@ -252,7 +252,7 @@ let rec expr_var_set Expr.Fixed.({pattern; meta}) =
match pattern with
| Var s -> Set.Poly.singleton (VVar s, meta)
| Lit _ -> Set.Poly.empty
| FunApp (_, _, exprs) -> union_recur exprs
| FunApp (_, exprs) -> union_recur exprs
| TernaryIf (expr1, expr2, expr3) -> union_recur [expr1; expr2; expr3]
| Indexed (expr, ix) ->
Set.Poly.union_list (expr_var_set expr :: List.map ix ~f:index_var_set)
Expand All @@ -270,7 +270,7 @@ and index_var_set ix =
let stmt_rhs stmt =
match stmt with
| Stmt.Fixed.Pattern.For vars -> Set.Poly.of_list [vars.lower; vars.upper]
| NRFunApp (_, _, exprs) -> Set.Poly.of_list exprs
| NRFunApp (_, exprs) -> Set.Poly.of_list exprs
| IfElse (rhs, _, _)
|While (rhs, _)
|Assignment (_, rhs)
Expand All @@ -296,7 +296,7 @@ let expr_assigned_var Expr.Fixed.({pattern; _}) =
(** See interface file *)
let rec summation_terms (Expr.Fixed.({pattern; _}) as rhs) =
match pattern with
| FunApp (_, "Plus__", [e1; e2]) ->
| FunApp (StanLib "Plus__", [e1; e2]) ->
List.append (summation_terms e1) (summation_terms e2)
| _ -> [rhs]

Expand Down Expand Up @@ -356,7 +356,7 @@ let expr_subst_stmt m = map_rec_stmt_loc (expr_subst_stmt_base m)
let rec expr_depth Expr.Fixed.({pattern; _}) =
match pattern with
| Var _ | Lit (_, _) -> 0
| FunApp (_, _, l) ->
| FunApp (_, l) ->
1
+ Option.value ~default:0
(List.max_elt ~compare:compare_int (List.map ~f:expr_depth l))
Expand Down Expand Up @@ -394,9 +394,9 @@ let rec update_expr_ad_levels autodiffable_variables
Expr.Typed.{e with meta= Meta.{e.meta with adlevel= AutoDiffable}}
else {e with meta= {e.meta with adlevel= DataOnly}}
| Lit (_, _) -> {e with meta= {e.meta with adlevel= DataOnly}}
| FunApp (o, f, l) ->
| FunApp (kind, l) ->
let l = List.map ~f:(update_expr_ad_levels autodiffable_variables) l in
{pattern= FunApp (o, f, l); meta= {e.meta with adlevel= ad_level_sup l}}
{pattern= FunApp (kind, l); meta= {e.meta with adlevel= ad_level_sup l}}
| TernaryIf (e1, e2, e3) ->
let e1 = update_expr_ad_levels autodiffable_variables e1 in
let e2 = update_expr_ad_levels autodiffable_variables e2 in
Expand Down
38 changes: 22 additions & 16 deletions src/analysis_and_optimization/Monotone_framework.ml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ let rec free_vars_expr (e : Expr.Typed.t) =
match e.pattern with
| Var x -> Set.Poly.singleton x
| Lit (_, _) -> Set.Poly.empty
| FunApp (_, f, l) ->
Set.Poly.union_list (Set.Poly.singleton f :: List.map ~f:free_vars_expr l)
| FunApp (kind, l) -> free_vars_fnapp kind l
| TernaryIf (e1, e2, e3) ->
Set.Poly.union_list (List.map ~f:free_vars_expr [e1; e2; e3])
| Indexed (e, l) ->
Expand All @@ -45,6 +44,13 @@ and free_vars_idx (i : Expr.Typed.t Index.t) =
| Single e | Upfrom e | MultiIndex e -> free_vars_expr e
| Between (e1, e2) -> Set.Poly.union (free_vars_expr e1) (free_vars_expr e2)

and free_vars_fnapp kind l =
let arg_vars = List.map ~f:free_vars_expr l in
match kind with
| Fun_kind.UserDefined f ->
Set.Poly.union_list (Set.Poly.singleton f :: List.map ~f:free_vars_expr l)
| _ -> Set.Poly.union_list arg_vars

(** Calculate the free (non-bound) variables in a statement *)
let rec free_vars_stmt
(s : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t) =
Expand All @@ -53,8 +59,7 @@ let rec free_vars_stmt
free_vars_expr e
| Assignment ((_, _, l), e) ->
Set.Poly.union_list (free_vars_expr e :: List.map ~f:free_vars_idx l)
| NRFunApp (_, f, l) ->
Set.Poly.union_list (Set.Poly.singleton f :: List.map ~f:free_vars_expr l)
| NRFunApp (kind, l) -> free_vars_fnapp kind l
| IfElse (e, b1, Some b2) ->
Set.Poly.union_list
[free_vars_expr e; free_vars_stmt b1.pattern; free_vars_stmt b2.pattern]
Expand Down Expand Up @@ -314,7 +319,7 @@ let constant_propagation_transfer
| Decl {decl_id= s; _} | Assignment ((s, _, _ :: _), _) ->
Map.remove m s
| TargetPE _
|NRFunApp (_, _, _)
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip
|IfElse (_, _, _)
|While (_, _)
Expand Down Expand Up @@ -373,7 +378,7 @@ let expression_propagation_transfer
in
Set.Poly.fold kills ~init:m ~f:kill_var
| TargetPE _
|NRFunApp (_, _, _)
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip
|IfElse (_, _, _)
|While (_, _)
Expand Down Expand Up @@ -414,7 +419,7 @@ let copy_propagation_transfer (globals : string Set.Poly.t)
in
Set.Poly.fold kills ~init:m ~f:kill_var
| TargetPE _
|NRFunApp (_, _, _)
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip
|IfElse (_, _, _)
|While (_, _)
Expand All @@ -435,11 +440,11 @@ let assigned_vars_stmt (s : (Expr.Typed.t, 'a) Stmt.Fixed.Pattern.t) =
match s with
| Assignment ((x, _, _), _) -> Set.Poly.singleton x
| TargetPE _ -> Set.Poly.singleton "target"
| NRFunApp (_, s, _) when String.suffix s 3 = "_lp" ->
| NRFunApp ((UserDefined s | StanLib s), _) when String.suffix s 3 = "_lp" ->
Set.Poly.singleton "target"
| For {loopvar= x; _} -> Set.Poly.singleton x
| Decl {decl_id= _; _}
|NRFunApp (_, _, _)
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip
|IfElse (_, _, _)
|While (_, _)
Expand Down Expand Up @@ -478,9 +483,10 @@ let reaching_definitions_transfer
|For {loopvar= x; _} ->
Set.filter p ~f:(fun (y, _) -> y = x)
| TargetPE _ -> Set.filter p ~f:(fun (y, _) -> y = "target")
| NRFunApp (_, s, _) when String.suffix s 3 = "_lp" ->
| NRFunApp ((UserDefined s | StanLib s), _)
when String.suffix s 3 = "_lp" ->
Set.filter p ~f:(fun (y, _) -> y = "target")
| NRFunApp (_, _, _)
| NRFunApp (_, _)
|Break | Continue | Return _ | Skip
|IfElse (_, _, _)
|While (_, _)
Expand Down Expand Up @@ -523,7 +529,7 @@ let live_variables_transfer (never_kill : string Set.Poly.t)
| Assignment ((x, _, []), _) | Decl {decl_id= x; _} ->
Set.Poly.singleton x
| TargetPE _
|NRFunApp (_, _, _)
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip
|IfElse (_, _, _)
|While (_, _)
Expand All @@ -542,7 +548,7 @@ let rec used_subexpressions_expr (e : Expr.Typed.t) =
(Expr.Typed.Set.singleton e)
( match e.pattern with
| Var _ | Lit (_, _) -> Expr.Typed.Set.empty
| FunApp (_, _, l) ->
| FunApp (_, l) ->
Expr.Typed.Set.union_list (List.map ~f:used_subexpressions_expr l)
| TernaryIf (e1, e2, e3) ->
Expr.Typed.Set.union_list
Expand Down Expand Up @@ -580,7 +586,7 @@ let rec used_expressions_stmt_help f
[ f e
; used_expressions_stmt_help f b1.pattern
; used_expressions_stmt_help f b2.pattern ]
| NRFunApp (_, _, l) -> Expr.Typed.Set.union_list (List.map ~f l)
| NRFunApp (_, l) -> Expr.Typed.Set.union_list (List.map ~f l)
| Decl _ | Return None | Break | Continue | Skip -> Expr.Typed.Set.empty
| IfElse (e, b, None) | While (e, b) ->
Expr.Typed.Set.union (f e) (used_expressions_stmt_help f b.pattern)
Expand Down Expand Up @@ -614,7 +620,7 @@ let top_used_expressions_stmt_help f
(Expr.Typed.Set.union_list
(List.map ~f:(used_expressions_idx_help f) l))
| While (e, _) | IfElse (e, _, _) -> f e
| NRFunApp (_, _, l) -> Expr.Typed.Set.union_list (List.map ~f l)
| NRFunApp (_, l) -> Expr.Typed.Set.union_list (List.map ~f l)
| Profile _ | Block _ | SList _ | Decl _
|Return None
|Break | Continue | Skip ->
Expand Down Expand Up @@ -899,7 +905,7 @@ let rec declared_variables_stmt
| Decl {decl_id= x; _} -> Set.Poly.singleton x
| Assignment (_, _)
|TargetPE _
|NRFunApp (_, _, _)
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip ->
Set.Poly.empty
| IfElse (_, b1, Some b2) ->
Expand Down
114 changes: 63 additions & 51 deletions src/analysis_and_optimization/Optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ let rec inline_function_expression propto adt fim
match pattern with
| Var _ -> ([], [], e)
| Lit (_, _) -> ([], [], e)
| FunApp (t, s, es) -> (
| FunApp (kind, es) -> (
let dse_list =
List.map ~f:(inline_function_expression propto adt fim) es
in
Expand All @@ -231,30 +231,45 @@ let rec inline_function_expression propto adt fim
List.concat (List.rev (List.map ~f:(function _, x, _ -> x) dse_list))
in
let es = List.map ~f:(function _, _, x -> x) dse_list in
let s = if propto then s else Middle.Utils.stdlib_distribution_name s in
match Map.find fim s with
| None -> (d_list, s_list, {e with pattern= FunApp (t, s, es)})
| Some (rt, args, b) ->
let x = Gensym.generate ~prefix:"inline_" () in
let handle = handle_early_returns (Some x) in
let d_list2, s_list2, (e : Expr.Typed.t) =
( [ Stmt.Fixed.Pattern.Decl
{decl_adtype= adt; decl_id= x; decl_type= Option.value_exn rt}
]
(* We should minimize the code that's having its variables
match kind with
| CompilerInternal _ ->
(d_list, s_list, {e with pattern= FunApp (kind, es)})
| UserDefined fname | StanLib fname -> (
let fname =
if propto then fname
else Middle.Utils.stdlib_distribution_name fname
in
match Map.find fim fname with
| None ->
let fun_kind =
match kind with
| Fun_kind.UserDefined _ -> Fun_kind.UserDefined fname
| _ -> StanLib fname
in
(d_list, s_list, {e with pattern= FunApp (fun_kind, es)})
| Some (rt, args, b) ->
let x = Gensym.generate ~prefix:"inline_" () in
let handle = handle_early_returns (Some x) in
let d_list2, s_list2, (e : Expr.Typed.t) =
( [ Stmt.Fixed.Pattern.Decl
{ decl_adtype= adt
; decl_id= x
; decl_type= Option.value_exn rt } ]
(* We should minimize the code that's having its variables
replaced to avoid conflict with the (two) new dummy
variables introduced by inlining *)
, [handle (replace_fresh_local_vars (subst_args_stmt args es b))]
, { pattern= Var x
; meta=
Expr.Typed.Meta.
{ type_= Type.to_unsized (Option.value_exn rt)
; adlevel= adt
; loc= Location_span.empty } } )
in
let d_list = d_list @ d_list2 in
let s_list = s_list @ s_list2 in
(d_list, s_list, e) )
, [ handle
(replace_fresh_local_vars (subst_args_stmt args es b)) ]
, { pattern= Var x
; meta=
Expr.Typed.Meta.
{ type_= Type.to_unsized (Option.value_exn rt)
; adlevel= adt
; loc= Location_span.empty } } )
in
let d_list = d_list @ d_list2 in
let s_list = s_list @ s_list2 in
(d_list, s_list, e) ) )
| TernaryIf (e1, e2, e3) ->
let dl1, sl1, e1 = inline_function_expression propto adt fim e1 in
let dl2, sl2, e2 = inline_function_expression propto adt fim e2 in
Expand Down Expand Up @@ -347,7 +362,7 @@ let rec inline_function_statement propto adt fim Stmt.Fixed.({pattern; meta}) =
| TargetPE e ->
let d, s, e = inline_function_expression propto adt fim e in
slist_concat_no_loc (d @ s) (TargetPE e)
| NRFunApp (t, s, es) ->
| NRFunApp (kind, es) ->
let dse_list =
List.map ~f:(inline_function_expression propto adt fim) es
in
Expand All @@ -362,14 +377,17 @@ let rec inline_function_statement propto adt fim Stmt.Fixed.({pattern; meta}) =
in
let es = List.map ~f:(function _, _, x -> x) dse_list in
slist_concat_no_loc (d_list @ s_list)
( match Map.find fim s with
| None -> NRFunApp (t, s, es)
| Some (_, args, b) ->
let b = replace_fresh_local_vars b in
let b = handle_early_returns None b in
(subst_args_stmt args es
{pattern= b; meta= Location_span.empty})
.pattern )
( match kind with
| CompilerInternal _ -> NRFunApp (kind, es)
| UserDefined s | StanLib s -> (
match Map.find fim s with
| None -> NRFunApp (kind, es)
| Some (_, args, b) ->
let b = replace_fresh_local_vars b in
let b = handle_early_returns None b in
(subst_args_stmt args es
{pattern= b; meta= Location_span.empty})
.pattern ) )
| Return e -> (
match e with
| None -> Return None
Expand Down Expand Up @@ -499,7 +517,7 @@ let rec contains_top_break_or_continue Stmt.Fixed.({pattern; _}) =
| Break | Continue -> true
| Assignment (_, _)
|TargetPE _
|NRFunApp (_, _, _)
|NRFunApp (_, _)
|Return _ | Decl _
|While (_, _)
|For _ | Skip ->
Expand Down Expand Up @@ -565,7 +583,7 @@ let unroll_loop_one_step_statement _ =
else
IfElse
( Expr.Fixed.
{lower with pattern= FunApp (StanLib, "Geq__", [upper; lower])}
{lower with pattern= FunApp (StanLib "Geq__", [upper; lower])}
, { pattern=
(let body_unrolled =
subst_args_stmt [loopvar] [lower]
Expand All @@ -581,8 +599,7 @@ let unroll_loop_one_step_statement _ =
{ lower with
pattern=
FunApp
( StanLib
, "Plus__"
( StanLib "Plus__"
, [lower; Expr.Helpers.loop_bottom] ) } }
; meta= Location_span.empty }
in
Expand Down Expand Up @@ -666,26 +683,21 @@ and accum_any pred b e = b || expr_any pred e

let can_side_effect_top_expr (e : Expr.Typed.t) =
match e.pattern with
| FunApp (t, f, _) ->
String.suffix f 3 = "_lp"
|| (t = CompilerInternal && f = Internal_fun.to_string FnReadParam)
|| (t = CompilerInternal && f = Internal_fun.to_string FnReadData)
|| (t = CompilerInternal && f = Internal_fun.to_string FnWriteParam)
|| (t = CompilerInternal && f = Internal_fun.to_string FnConstrain)
|| (t = CompilerInternal && f = Internal_fun.to_string FnValidateSize)
|| (t = CompilerInternal && f = Internal_fun.to_string FnValidateSize)
|| t = CompilerInternal
&& f = Internal_fun.to_string FnValidateSizeSimplex
|| t = CompilerInternal
&& f = Internal_fun.to_string FnValidateSizeUnitVector
|| (t = CompilerInternal && f = Internal_fun.to_string FnUnconstrain)
| FunApp ((UserDefined f | StanLib f), _) -> String.suffix f 3 = "_lp"
| FunApp
( CompilerInternal
( FnReadParam _ | FnReadData | FnWriteParam | FnConstrain _
| FnValidateSize | FnValidateSizeSimplex | FnValidateSizeUnitVector
| FnUnconstrain _ )
, _ ) ->
true
| _ -> false

let cannot_duplicate_expr (e : Expr.Typed.t) =
let pred e =
can_side_effect_top_expr e
|| ( match e.pattern with
| FunApp (_, f, _) -> String.suffix f 4 = "_rng"
| FunApp ((UserDefined f | StanLib f), _) -> String.suffix f 4 = "_rng"
| _ -> false )
|| (preserve_stability && UnsizedType.is_autodiffable e.meta.type_)
in
Expand Down Expand Up @@ -746,7 +758,7 @@ let dead_code_elimination (mir : Program.Typed.t) =
due to side effects. *)
(* TODO: maybe we should revisit that. *)
| Decl _ | TargetPE _
|NRFunApp (_, _, _)
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip ->
stmt
| IfElse (e, b1, b2) -> (
Expand Down
Loading