diff --git a/include/expr.h b/include/expr.h index f01f164f..c1d2bd86 100644 --- a/include/expr.h +++ b/include/expr.h @@ -56,7 +56,16 @@ typedef struct expr // ------------------------------------------------------------------------ double *value; CSR_Matrix *jacobian; + CSC_Matrix *jacobian_csc; + int *csc_work; /* workspace for CSR-CSC conversion */ + + /* jacobian_csc_filled is only used for affine functions to avoid redundant + conversions. Could become relevant for non-affine functions if we start + supporting common subexpressions on the Python side. */ + bool jacobian_csc_filled; CSR_Matrix *wsum_hess; + CSR_Matrix *hess_term1; /* Jg^T D Jg workspace */ + CSR_Matrix *hess_term2; /* child wsum_hess workspace */ forward_fn forward; jacobian_init_fn jacobian_init; wsum_hess_init_fn wsum_hess_init; @@ -67,6 +76,7 @@ typedef struct expr // other things // ------------------------------------------------------------------------ is_affine_fn is_affine; + double *local_jac_diag; /* cached f'(g(x)) diagonal */ local_jacobian_fn local_jacobian; /* used by elementwise univariate atoms*/ local_wsum_hess_fn local_wsum_hess; /* used by elementwise univariate atoms*/ free_type_data_fn free_type_data; /* Cleanup for type-specific fields */ @@ -83,6 +93,10 @@ void init_expr(expr *node, int d1, int d2, int n_vars, forward_fn forward, void free_expr(expr *node); +/* Initialize CSC form of the Jacobian from the CSR Jacobian. + * Must be called after jacobian_init. */ +void jacobian_csc_init(expr *node); + /* Reference counting helpers */ void expr_retain(expr *node); diff --git a/src/elementwise_full_dom/common.c b/src/elementwise_full_dom/common.c index 762b9ece..6e2bcb79 100644 --- a/src/elementwise_full_dom/common.c +++ b/src/elementwise_full_dom/common.c @@ -1,6 +1,8 @@ #include "elementwise_full_dom.h" #include "subexpr.h" +#include "utils/CSC_Matrix.h" #include "utils/CSR_Matrix.h" +#include "utils/CSR_sum.h" #include #include #include @@ -20,7 +22,6 @@ void jacobian_init_elementwise(expr *node) } node->jacobian->p[node->size] = node->size; } - /* otherwise it will be a linear operator */ else { /* jacobian of h(x) = f(g(x)) is Jf @ Jg, and here Jf is diagonal */ @@ -28,6 +29,7 @@ void jacobian_init_elementwise(expr *node) CSR_Matrix *Jg = child->jacobian; node->jacobian = new_csr_matrix(Jg->m, Jg->n, Jg->nnz); node->dwork = (double *) malloc(node->size * sizeof(double)); + node->local_jac_diag = (double *) malloc(node->size * sizeof(double)); /* copy sparsity pattern of child */ memcpy(node->jacobian->p, Jg->p, sizeof(int) * (Jg->m + 1)); @@ -48,7 +50,8 @@ void eval_jacobian_elementwise(expr *node) /* jacobian of h(x) = f(g(x)) is Jf @ Jg, and here Jf is diagonal */ child->eval_jacobian(child); CSR_Matrix *Jg = child->jacobian; - node->local_jacobian(node, node->dwork); + node->local_jacobian(node, node->local_jac_diag); + memcpy(node->dwork, node->local_jac_diag, node->size * sizeof(double)); diag_csr_mult_fill_values(node->dwork, Jg, node->jacobian); } } @@ -59,7 +62,7 @@ void wsum_hess_init_elementwise(expr *node) int id = child->var_id; int i; - /* if the variable is a child*/ + /* if the variable is a child */ if (id != NOT_A_VARIABLE) { node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, node->size); @@ -75,11 +78,38 @@ void wsum_hess_init_elementwise(expr *node) node->wsum_hess->p[i] = node->size; } } - /* otherwise it will be a linear operator */ else { - linear_op_expr *lin_child = (linear_op_expr *) child; - node->wsum_hess = ATA_alloc(lin_child->A_csc); + /* Hessian of h(x) = w^T f(g(x) is term1 + term 2 where + term1 = J_g^T @ D @ J_g with D = sum_i w_i Hf_i, + term2 = sum_i (J_f^T w)_i^T Hg_i. + + For elementwise functions, D is diagonal. */ + jacobian_csc_init(child); + CSC_Matrix *Jg = child->jacobian_csc; + + if (child->is_affine(child)) + { + node->wsum_hess = ATA_alloc(Jg); + } + else + { + /* term1: Jg^T @ D @ Jg */ + node->hess_term1 = ATA_alloc(Jg); + + /* term2: child's Hessian */ + child->wsum_hess_init(child); + CSR_Matrix *Hg = child->wsum_hess; + node->hess_term2 = new_csr_matrix(Hg->m, Hg->n, Hg->nnz); + memcpy(node->hess_term2->p, Hg->p, (Hg->m + 1) * sizeof(int)); + memcpy(node->hess_term2->i, Hg->i, Hg->nnz * sizeof(int)); + + /* wsum_hess = term1 + term2 */ + int max_nnz = node->hess_term1->nnz + node->hess_term2->nnz; + node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, max_nnz); + sum_csr_matrices_fill_sparsity(node->hess_term1, node->hess_term2, + node->wsum_hess); + } } } @@ -93,10 +123,43 @@ void eval_wsum_hess_elementwise(expr *node, const double *w) } else { - /* Child will be a linear operator */ - linear_op_expr *lin_child = (linear_op_expr *) child; - node->local_wsum_hess(node, node->dwork, w); - ATDA_fill_values(lin_child->A_csc, node->dwork, node->wsum_hess); + if (child->is_affine(child)) + { + if (!child->jacobian_csc_filled) + { + csr_to_csc_fill_values(child->jacobian, child->jacobian_csc, + child->csc_work); + child->jacobian_csc_filled = true; + } + + node->local_wsum_hess(node, node->dwork, w); + ATDA_fill_values(child->jacobian_csc, node->dwork, node->wsum_hess); + } + else + { + /* refresh CSC jacobian values */ + csr_to_csc_fill_values(child->jacobian, child->jacobian_csc, + child->csc_work); + + /* term1: Jg^T @ D @ Jg */ + node->local_wsum_hess(node, node->dwork, w); + ATDA_fill_values(child->jacobian_csc, node->dwork, node->hess_term1); + + /* term2: child Hessian with weight Jf^T w */ + memcpy(node->dwork, node->local_jac_diag, node->size * sizeof(double)); + for (int k = 0; k < node->size; k++) + { + node->dwork[k] *= w[k]; + } + + child->eval_wsum_hess(child, node->dwork); + memcpy(node->hess_term2->x, child->wsum_hess->x, + child->wsum_hess->nnz * sizeof(double)); + + /* wsum_hess = term1 + term2 */ + sum_csr_matrices_fill_values(node->hess_term1, node->hess_term2, + node->wsum_hess); + } } } diff --git a/src/expr.c b/src/expr.c index 99a71755..faaa3e4b 100644 --- a/src/expr.c +++ b/src/expr.c @@ -16,6 +16,7 @@ * limitations under the License. */ #include "expr.h" +#include "utils/CSC_Matrix.h" #include "utils/int_double_pair.h" #include #include @@ -41,6 +42,12 @@ void init_expr(expr *node, int d1, int d2, int n_vars, forward_fn forward, node->free_type_data = free_type_data; } +void jacobian_csc_init(expr *node) +{ + node->csc_work = (int *) malloc(node->n_vars * sizeof(int)); + node->jacobian_csc = csr_to_csc_fill_sparsity(node->jacobian, node->csc_work); +} + void free_expr(expr *node) { if (node == NULL) return; @@ -63,8 +70,13 @@ void free_expr(expr *node) /* free value array and jacobian */ free(node->value); free_csr_matrix(node->jacobian); + free_csc_matrix(node->jacobian_csc); + free(node->csc_work); free_csr_matrix(node->wsum_hess); + free_csr_matrix(node->hess_term1); + free_csr_matrix(node->hess_term2); free(node->dwork); + free(node->local_jac_diag); free(node->iwork); node->value = NULL; node->jacobian = NULL; diff --git a/tests/all_tests.c b/tests/all_tests.c index df526bf8..d0cfa2ce 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -65,6 +65,7 @@ #include "wsum_hess/elementwise/test_trig.h" #include "wsum_hess/elementwise/test_xexp.h" #include "wsum_hess/test_broadcast.h" +#include "wsum_hess/test_chain_rule_wsum_hess.h" #include "wsum_hess/test_const_scalar_mult.h" #include "wsum_hess/test_const_vector_mult.h" #include "wsum_hess/test_hstack.h" @@ -259,6 +260,11 @@ int main(void) mu_run_test(test_wsum_hess_trace_log_variable, tests_run); mu_run_test(test_wsum_hess_trace_composite, tests_run); mu_run_test(test_wsum_hess_transpose, tests_run); + mu_run_test(test_wsum_hess_exp_sum, tests_run); + mu_run_test(test_wsum_hess_exp_sum_mult, tests_run); + mu_run_test(test_wsum_hess_exp_sum_matmul, tests_run); + mu_run_test(test_wsum_hess_sin_sum_axis0_matmul, tests_run); + mu_run_test(test_wsum_hess_logistic_sum_axis0_matmul, tests_run); printf("\n--- Utility Tests ---\n"); mu_run_test(test_cblas_ddot, tests_run); diff --git a/tests/wsum_hess/test_chain_rule_wsum_hess.h b/tests/wsum_hess/test_chain_rule_wsum_hess.h new file mode 100644 index 00000000..6d61bfea --- /dev/null +++ b/tests/wsum_hess/test_chain_rule_wsum_hess.h @@ -0,0 +1,100 @@ +#include "affine.h" +#include "bivariate.h" +#include "elementwise_full_dom.h" +#include "minunit.h" +#include "numerical_diff.h" + +const char *test_wsum_hess_exp_sum(void) +{ + double u_vals[3] = {1.0, 2.0, 3.0}; + double w = 1.0; + + expr *x = new_variable(3, 1, 0, 3); + expr *sum_x = new_sum(x, -1); + expr *exp_sum_x = new_exp(sum_x); + + mu_assert("check_wsum_hess failed", + check_wsum_hess(exp_sum_x, u_vals, &w, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(exp_sum_x); + return 0; +} + +const char *test_wsum_hess_exp_sum_mult(void) +{ + double u_vals[4] = {1.0, 2.0, 3.0, 4.0}; + double w = 1.0; + + expr *x = new_variable(2, 1, 0, 4); + expr *y = new_variable(2, 1, 2, 4); + expr *xy = new_elementwise_mult(x, y); + expr *sum_xy = new_sum(xy, -1); + expr *exp_sum_xy = new_exp(sum_xy); + + mu_assert("check_wsum_hess failed", + check_wsum_hess(exp_sum_xy, u_vals, &w, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(exp_sum_xy); + return 0; +} + +const char *test_wsum_hess_exp_sum_matmul(void) +{ + /* exp(sum(X @ Y)) where X is 2x3, Y is 3x2 + * n_vars = 6 + 6 = 12 */ + double u_vals[12] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.5, 1.5, 2.5, 0.1, 0.2, 0.3}; + double w = 1.0; + + expr *X = new_variable(2, 3, 0, 12); + expr *Y = new_variable(3, 2, 6, 12); + expr *XY = new_matmul(X, Y); + expr *sum_XY = new_sum(XY, -1); + expr *exp_sum_XY = new_exp(sum_XY); + + mu_assert("check_wsum_hess failed", + check_wsum_hess(exp_sum_XY, u_vals, &w, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(exp_sum_XY); + return 0; +} + +const char *test_wsum_hess_sin_sum_axis0_matmul(void) +{ + /* sin(sum(X @ Y, axis=0)) where X is 2x3, Y is 3x2 + * X@Y is 2x2, sum(axis=0) gives 1x2, sin gives 1x2 + * n_vars = 6 + 6 = 12 */ + double u_vals[12] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.5, 1.5, 2.5, 0.1, 0.2, 0.3}; + double w[2] = {1.0, 1.0}; + + expr *X = new_variable(2, 3, 0, 12); + expr *Y = new_variable(3, 2, 6, 12); + expr *XY = new_matmul(X, Y); + expr *sum_XY = new_sum(XY, 0); + expr *sin_sum_XY = new_sin(sum_XY); + + mu_assert("check_wsum_hess failed", + check_wsum_hess(sin_sum_XY, u_vals, w, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(sin_sum_XY); + return 0; +} + +const char *test_wsum_hess_logistic_sum_axis0_matmul(void) +{ + /* logistic(sum(X @ Y, axis=0)) where X is 2x3, Y is 3x2 + * n_vars = 6 + 6 = 12 */ + double u_vals[12] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.5, 1.5, 2.5, 0.1, 0.2, 0.3}; + double w[2] = {1.0, 1.0}; + + expr *X = new_variable(2, 3, 0, 12); + expr *Y = new_variable(3, 2, 6, 12); + expr *XY = new_matmul(X, Y); + expr *sum_XY = new_sum(XY, 0); + expr *logistic_sum_XY = new_logistic(sum_XY); + + mu_assert("check_wsum_hess failed", + check_wsum_hess(logistic_sum_XY, u_vals, w, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(logistic_sum_XY); + return 0; +}