From eafee08153ca061293c0de74f9013284bf175509 Mon Sep 17 00:00:00 2001 From: dance858 Date: Thu, 26 Mar 2026 12:19:11 -0700 Subject: [PATCH 1/2] jacobian chain rule --- src/affine/linear_op.c | 10 ++++- src/elementwise_full_dom/common.c | 19 +++++---- tests/all_tests.c | 3 ++ .../jacobian_tests/test_chain_rule_jacobian.h | 39 +++++++++++++++++++ 4 files changed, 61 insertions(+), 10 deletions(-) create mode 100644 tests/jacobian_tests/test_chain_rule_jacobian.h diff --git a/src/affine/linear_op.c b/src/affine/linear_op.c index a8d2863e..d9895e15 100644 --- a/src/affine/linear_op.c +++ b/src/affine/linear_op.c @@ -74,14 +74,20 @@ static void jacobian_init(expr *node) node->jacobian = ((linear_op_expr *) node)->A_csr; } +static void eval_jacobian(expr *node) +{ + /* Linear operator jacobian never changes - nothing to evaluate */ + (void) node; +} + expr *new_linear(expr *u, const CSR_Matrix *A, const double *b) { assert(u->d2 == 1); /* Allocate the type-specific struct */ linear_op_expr *lin_node = (linear_op_expr *) calloc(1, sizeof(linear_op_expr)); expr *node = &lin_node->base; - init_expr(node, A->m, 1, u->n_vars, forward, jacobian_init, NULL, is_affine, - NULL, NULL, free_type_data); + init_expr(node, A->m, 1, u->n_vars, forward, jacobian_init, eval_jacobian, + is_affine, NULL, NULL, free_type_data); node->left = u; expr_retain(u); diff --git a/src/elementwise_full_dom/common.c b/src/elementwise_full_dom/common.c index f5ac58ac..762b9ece 100644 --- a/src/elementwise_full_dom/common.c +++ b/src/elementwise_full_dom/common.c @@ -1,5 +1,6 @@ #include "elementwise_full_dom.h" #include "subexpr.h" +#include "utils/CSR_Matrix.h" #include #include #include @@ -22,14 +23,15 @@ void jacobian_init_elementwise(expr *node) /* otherwise it will be a linear operator */ else { - CSR_Matrix *J = ((linear_op_expr *) child)->A_csr; - node->jacobian = new_csr_matrix(J->m, J->n, J->nnz); - + /* jacobian of h(x) = f(g(x)) is Jf @ Jg, and here Jf is diagonal */ + child->jacobian_init(child); + CSR_Matrix *Jg = child->jacobian; + node->jacobian = new_csr_matrix(Jg->m, Jg->n, Jg->nnz); node->dwork = (double *) malloc(node->size * sizeof(double)); /* copy sparsity pattern of child */ - memcpy(node->jacobian->p, J->p, sizeof(int) * (J->m + 1)); - memcpy(node->jacobian->i, J->i, sizeof(int) * J->nnz); + memcpy(node->jacobian->p, Jg->p, sizeof(int) * (Jg->m + 1)); + memcpy(node->jacobian->i, Jg->i, sizeof(int) * Jg->nnz); } } @@ -43,10 +45,11 @@ void eval_jacobian_elementwise(expr *node) } else { - /* Child will be a linear operator */ - linear_op_expr *lin_child = (linear_op_expr *) child; + /* jacobian of h(x) = f(g(x)) is Jf @ Jg, and here Jf is diagonal */ + child->eval_jacobian(child); + CSR_Matrix *Jg = child->jacobian; node->local_jacobian(node, node->dwork); - diag_csr_mult_fill_values(node->dwork, lin_child->A_csr, node->jacobian); + diag_csr_mult_fill_values(node->dwork, Jg, node->jacobian); } } diff --git a/tests/all_tests.c b/tests/all_tests.c index 4f11151f..df526bf8 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -22,6 +22,7 @@ #include "forward_pass/test_prod_axis_one.h" #include "forward_pass/test_prod_axis_zero.h" #include "jacobian_tests/test_broadcast.h" +#include "jacobian_tests/test_chain_rule_jacobian.h" #include "jacobian_tests/test_composite_exp.h" #include "jacobian_tests/test_const_scalar_mult.h" #include "jacobian_tests/test_const_vector_mult.h" @@ -129,6 +130,8 @@ int main(void) mu_run_test(test_jacobian_log, tests_run); mu_run_test(test_jacobian_log_matrix, tests_run); mu_run_test(test_jacobian_composite_exp, tests_run); + mu_run_test(test_jacobian_exp_sum, tests_run); + mu_run_test(test_jacobian_exp_sum_mult, tests_run); mu_run_test(test_jacobian_composite_exp_add, tests_run); mu_run_test(test_jacobian_const_scalar_mult_log_vector, tests_run); mu_run_test(test_jacobian_const_scalar_mult_log_matrix, tests_run); diff --git a/tests/jacobian_tests/test_chain_rule_jacobian.h b/tests/jacobian_tests/test_chain_rule_jacobian.h new file mode 100644 index 00000000..b0c40153 --- /dev/null +++ b/tests/jacobian_tests/test_chain_rule_jacobian.h @@ -0,0 +1,39 @@ +#include "affine.h" +#include "bivariate.h" +#include "elementwise_full_dom.h" +#include "minunit.h" +#include "numerical_diff.h" + +const char *test_jacobian_exp_sum(void) +{ + double u_vals[3] = {1.0, 2.0, 3.0}; + + expr *x = new_variable(3, 1, 0, 3); + expr *sum_x = new_sum(x, -1); + expr *exp_sum_x = new_exp(sum_x); + + mu_assert("check_jacobian failed", + check_jacobian(exp_sum_x, u_vals, + NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(exp_sum_x); + return 0; +} + +const char *test_jacobian_exp_sum_mult(void) +{ + double u_vals[4] = {1.0, 2.0, 3.0, 4.0}; + + expr *x = new_variable(2, 1, 0, 4); + expr *y = new_variable(2, 1, 2, 4); + expr *xy = new_elementwise_mult(x, y); + expr *sum_xy = new_sum(xy, -1); + expr *exp_sum_xy = new_exp(sum_xy); + + mu_assert("check_jacobian failed", + check_jacobian(exp_sum_xy, u_vals, + NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(exp_sum_xy); + return 0; +} From 06df767c5b49f0e5de6a9cbcb83b192f551ce7a3 Mon Sep 17 00:00:00 2001 From: dance858 Date: Thu, 26 Mar 2026 12:21:12 -0700 Subject: [PATCH 2/2] run formatter --- tests/jacobian_tests/test_chain_rule_jacobian.h | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/jacobian_tests/test_chain_rule_jacobian.h b/tests/jacobian_tests/test_chain_rule_jacobian.h index b0c40153..f2070b37 100644 --- a/tests/jacobian_tests/test_chain_rule_jacobian.h +++ b/tests/jacobian_tests/test_chain_rule_jacobian.h @@ -13,8 +13,7 @@ const char *test_jacobian_exp_sum(void) expr *exp_sum_x = new_exp(sum_x); mu_assert("check_jacobian failed", - check_jacobian(exp_sum_x, u_vals, - NUMERICAL_DIFF_DEFAULT_H)); + check_jacobian(exp_sum_x, u_vals, NUMERICAL_DIFF_DEFAULT_H)); free_expr(exp_sum_x); return 0; @@ -31,8 +30,7 @@ const char *test_jacobian_exp_sum_mult(void) expr *exp_sum_xy = new_exp(sum_xy); mu_assert("check_jacobian failed", - check_jacobian(exp_sum_xy, u_vals, - NUMERICAL_DIFF_DEFAULT_H)); + check_jacobian(exp_sum_xy, u_vals, NUMERICAL_DIFF_DEFAULT_H)); free_expr(exp_sum_xy); return 0;