Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 42 additions & 15 deletions src/stan_math_backend/Stan_math_code_gen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -367,14 +367,22 @@ let pp_validate_data ppf (name, st) =

let pp_mul ppf () = pf ppf " * "

let get_param_st lst =
let get_unconstrained_param_st lst =
match lst with
| _, {Program.out_block= Parameters; out_unconstrained_st= st; _} -> (
match SizedType.get_dims_io st with
| [] -> Some [Expr.Helpers.loop_bottom]
| ls -> Some ls )
| _ -> None

let get_constrained_param_st lst =
match lst with
| _, {Program.out_block= Parameters; out_constrained_st= st; _} -> (
match SizedType.get_dims_io st with
| [] -> Some [Expr.Helpers.loop_bottom]
| ls -> Some ls )
| _ -> None

let pp_num_param ppf (dims : Expr.Typed.t list) =
match dims with
| [a] -> pf ppf "@[%a@]@," (list ~sep:pp_mul pp_expr) [a]
Expand Down Expand Up @@ -425,7 +433,9 @@ let pp_ctor ppf p =
pp_located_error ppf
(pp_block, (list ~sep:cut pp_stmt_topdecl_size_only, prepare_data)) ;
cut ppf () ;
let output_params = List.filter_map ~f:get_param_st output_vars in
let output_params =
List.filter_map ~f:get_unconstrained_param_st output_vars
in
let pp_plus ppf () = pf ppf " + " in
let pp_set_params ppf pars =
(list ~sep:pp_plus pp_num_param) ppf pars
Expand Down Expand Up @@ -857,37 +867,54 @@ let pp_transform_inits ppf {Program.output_vars; _} =
in
let param_names = List.filter_map ~f:list_names output_vars in
let list_len = List.length param_names in
let output_params = List.filter_map ~f:get_param_st output_vars in
let constrained_params =
List.filter_map ~f:get_constrained_param_st output_vars
in
let get_names ppf () =
let add_param = fmt "%S" in
pf ppf "@[<hov -1> constexpr std::array<const char*, %i> names__{%a};@,"
let add_param = fmt "%S@," in
pf ppf "@[<hov 2> constexpr std::array<const char*, %i> names__{%a};@]@,"
list_len
(list ~sep:comma add_param)
param_names
in
let get_sizes ppf () =
match output_params with
| [] -> pf ppf " const std::array<Eigen::Index, 0> num_params__{};@]@,"
let get_constrain_param_size_arr ppf () =
match constrained_params with
| [] ->
pf ppf
"@[<hov 2> const std::array<Eigen::Index, 0> \
constrain_param_sizes__{};@]@,"
| _ ->
let pp_set_params ppf pars = (list ~sep:comma pp_num_param) ppf pars in
pf ppf " const std::array<Eigen::Index, %i> num_params__{%a};@]@,"
list_len pp_set_params output_params
pf ppf
"@[<hov 2> const std::array<Eigen::Index, %i> \
constrain_param_sizes__{%a};@]@,"
list_len pp_set_params constrained_params
in
let get_constrained_param_size ppf () =
pf ppf
"@[<hov 2> const auto num_constrained_params__ = std::accumulate(@, \
constrain_param_sizes__.begin(),@,@ constrain_param_sizes__.end(), \
0);@]@,"
in
let pp_body ppf =
pf ppf "%a" (list ~sep:cut string)
[ " std::vector<double> params_r_flat__(num_params_r__);"
[ " std::vector<double> params_r_flat__(num_constrained_params__);"
; " Eigen::Index size_iter__ = 0;"; " Eigen::Index flat_iter__ = 0;"
; " for (auto&& param_name__ : names__) {"
; " const auto param_vec__ = context.vals_r(param_name__);"
; " for (Eigen::Index i = 0; i < num_params__[size_iter__]; ++i) {"
; " for (Eigen::Index i = 0; i < \
constrain_param_sizes__[size_iter__]; ++i) {"
; " params_r_flat__[flat_iter__] = param_vec__[i];"
; " ++flat_iter__;"; " }"; " ++size_iter__;"; " }"
; " vars.resize(params_r_flat__.size());"
; " vars.resize(num_params_r__);"
; " transform_inits_impl(params_r_flat__, params_i, vars, pstream__);" ]
in
let cv_attr = ["const"] in
let blah ppf = pf ppf "%a %a" get_names () get_sizes in
pp_method ppf "void" "transform_inits" params blah
let intro ppf =
pf ppf "%a %a %a" get_names () get_constrain_param_size_arr ()
get_constrained_param_size
in
pp_method ppf "void" "transform_inits" params intro
(fun ppf -> pp_body ppf)
~cv_attr

Expand Down
10 changes: 6 additions & 4 deletions test/integration/cli-args/filename_good.expected
Original file line number Diff line number Diff line change
Expand Up @@ -281,20 +281,22 @@ class filename_good_model final : public model_base_crtp<filename_good_model> {
std::vector<double>& vars,
std::ostream* pstream__ = nullptr) const {
constexpr std::array<const char*, 0> names__{};
const std::array<Eigen::Index, 0> num_params__{};
const std::array<Eigen::Index, 0> constrain_param_sizes__{};
const auto num_constrained_params__ = std::accumulate(
constrain_param_sizes__.begin(), constrain_param_sizes__.end(), 0);

std::vector<double> params_r_flat__(num_params_r__);
std::vector<double> params_r_flat__(num_constrained_params__);
Eigen::Index size_iter__ = 0;
Eigen::Index flat_iter__ = 0;
for (auto&& param_name__ : names__) {
const auto param_vec__ = context.vals_r(param_name__);
for (Eigen::Index i = 0; i < num_params__[size_iter__]; ++i) {
for (Eigen::Index i = 0; i < constrain_param_sizes__[size_iter__]; ++i) {
params_r_flat__[flat_iter__] = param_vec__[i];
++flat_iter__;
}
++size_iter__;
}
vars.resize(params_r_flat__.size());
vars.resize(num_params_r__);
transform_inits_impl(params_r_flat__, params_i, vars, pstream__);
} // transform_inits()

Expand Down
14 changes: 8 additions & 6 deletions test/integration/good/code-gen/cl.expected
Original file line number Diff line number Diff line change
Expand Up @@ -1617,22 +1617,24 @@ class optimize_glm_model final : public model_base_crtp<optimize_glm_model> {
std::vector<double>& vars,
std::ostream* pstream__ = nullptr) const {
constexpr std::array<const char*, 9> names__{"alpha_v", "beta", "cuts",
"sigma", "alpha", "phi", "X_p", "beta_m", "X_rv_p"};
const std::array<Eigen::Index, 9> num_params__{k, k, k, 1, 1, 1,
(n * k), (n * k), n};
"sigma", "alpha", "phi", "X_p", "beta_m", "X_rv_p"};
const std::array<Eigen::Index, 9> constrain_param_sizes__{k, k,
k, 1, 1, 1, (n * k), (n * k), n};
const auto num_constrained_params__ = std::accumulate(
constrain_param_sizes__.begin(), constrain_param_sizes__.end(), 0);

std::vector<double> params_r_flat__(num_params_r__);
std::vector<double> params_r_flat__(num_constrained_params__);
Eigen::Index size_iter__ = 0;
Eigen::Index flat_iter__ = 0;
for (auto&& param_name__ : names__) {
const auto param_vec__ = context.vals_r(param_name__);
for (Eigen::Index i = 0; i < num_params__[size_iter__]; ++i) {
for (Eigen::Index i = 0; i < constrain_param_sizes__[size_iter__]; ++i) {
params_r_flat__[flat_iter__] = param_vec__[i];
++flat_iter__;
}
++size_iter__;
}
vars.resize(params_r_flat__.size());
vars.resize(num_params_r__);
transform_inits_impl(params_r_flat__, params_i, vars, pstream__);
} // transform_inits()

Expand Down
Loading