diff --git a/source/lib/src/gpu/tabulate.cu b/source/lib/src/gpu/tabulate.cu index 72c4c7a4e1..a3d9cc0e3a 100644 --- a/source/lib/src/gpu/tabulate.cu +++ b/source/lib/src/gpu/tabulate.cu @@ -1,3 +1,5 @@ +#include + #include "device.h" #include "tabulate.h" @@ -47,6 +49,46 @@ __device__ void GpuSyncThreads() { #endif } +template +__forceinline__ __device__ FPTYPE nextafter_device(const FPTYPE& from, + const FPTYPE& to); + +template <> +__forceinline__ __device__ float nextafter_device(const float& from, + const float& to) { + return nextafterf(from, to); +} + +template <> +__forceinline__ __device__ double nextafter_device(const double& from, + const double& to) { + return nextafter(from, to); +} + +template +__forceinline__ __device__ int locate_high_tail_xx(const FPTYPE& lower, + const FPTYPE& upper, + const FPTYPE& max, + const FPTYPE& stride0, + const FPTYPE& stride1) { + const FPTYPE boundary_xx = nextafter_device(max, lower); + const int first_stride = int((upper - lower) / stride0); + return first_stride + int((boundary_xx - upper) / stride1); +} + +template +__forceinline__ __device__ int locate_high_tail_xx_se_t(const FPTYPE& lower, + const FPTYPE& upper, + const FPTYPE& min, + const FPTYPE& max, + const FPTYPE& stride0, + const FPTYPE& stride1) { + const FPTYPE boundary_xx = nextafter_device(max, min); + const int first_stride = + int((lower - min) / stride1) + int((upper - lower) / stride0); + return first_stride + int((boundary_xx - upper) / stride1); +} + template __forceinline__ __device__ void locate_xx_se_a(FPTYPE& xx, int& table_idx, @@ -54,10 +96,14 @@ __forceinline__ __device__ void locate_xx_se_a(FPTYPE& xx, const FPTYPE& upper, const FPTYPE& max, const FPTYPE& stride0, - const FPTYPE& stride1) { + const FPTYPE& stride1, + FPTYPE& extrapolate_delta) { + const FPTYPE orig_xx = xx; + extrapolate_delta = (FPTYPE)0.; if (xx < lower) { table_idx = 0; xx = (FPTYPE)0.; + extrapolate_delta = orig_xx - lower; } else if (xx < upper) { table_idx = (int)((xx - lower) / stride0); xx -= (table_idx * stride0 + lower); @@ -66,9 +112,10 @@ __forceinline__ __device__ void locate_xx_se_a(FPTYPE& xx, table_idx = first_stride + (int)((xx - upper) / stride1); xx -= ((table_idx - first_stride) * stride1 + upper); } else { - table_idx = - int((upper - lower) / stride0) + (int)((max - upper) / stride1) - 1; - xx = (FPTYPE)0.; + int first_stride = int((upper - lower) / stride0); + table_idx = locate_high_tail_xx(lower, upper, max, stride0, stride1); + xx = max - ((table_idx - first_stride) * stride1 + upper); + extrapolate_delta = orig_xx - max; } } @@ -80,10 +127,14 @@ __forceinline__ __device__ void locate_xx_se_t(FPTYPE& xx, const FPTYPE& min, const FPTYPE& max, const FPTYPE& stride0, - const FPTYPE& stride1) { + const FPTYPE& stride1, + FPTYPE& extrapolate_delta) { + const FPTYPE orig_xx = xx; + extrapolate_delta = (FPTYPE)0.; if (xx < min) { table_idx = 0; xx = (FPTYPE)0.; + extrapolate_delta = orig_xx - min; } else if (xx < lower) { table_idx = (int)((xx - min) / stride1); xx -= (table_idx * stride1 + min); @@ -97,9 +148,12 @@ __forceinline__ __device__ void locate_xx_se_t(FPTYPE& xx, table_idx = first_stride + (int)((xx - upper) / stride1); xx -= ((table_idx - first_stride) * stride1 + upper); } else { - table_idx = int((lower - min) / stride1) + int((upper - lower) / stride0) + - (int)((max - upper) / stride1) - 1; - xx = (FPTYPE)0.; + int first_stride = + int((lower - min) / stride1) + int((upper - lower) / stride0); + table_idx = + locate_high_tail_xx_se_t(lower, upper, min, max, stride0, stride1); + xx = max - ((table_idx - first_stride) * stride1 + upper); + extrapolate_delta = orig_xx - max; } } @@ -112,10 +166,14 @@ __forceinline__ __device__ void locate_xx_se_t_tebd(FPTYPE& xx, const FPTYPE& min, const FPTYPE& max, const FPTYPE& stride0, - const FPTYPE& stride1) { + const FPTYPE& stride1, + FPTYPE& extrapolate_delta) { + const FPTYPE orig_xx = xx; + extrapolate_delta = (FPTYPE)0.; if (xx < min) { table_idx = 0; xx = (FPTYPE)0.; + extrapolate_delta = orig_xx - min; } else if (xx < lower) { table_idx = (int)((xx - min) / stride1); xx -= (table_idx * stride1 + min); @@ -129,9 +187,12 @@ __forceinline__ __device__ void locate_xx_se_t_tebd(FPTYPE& xx, table_idx = first_stride + (int)((xx - upper) / stride1); xx -= ((table_idx - first_stride) * stride1 + upper); } else { - table_idx = int((lower - min) / stride1) + int((upper - lower) / stride0) + - (int)((max - upper) / stride1) - 1; - xx = (FPTYPE)0.; + int first_stride = + int((lower - min) / stride1) + int((upper - lower) / stride0); + table_idx = + locate_high_tail_xx_se_t(lower, upper, min, max, stride0, stride1); + xx = max - ((table_idx - first_stride) * stride1 + upper); + extrapolate_delta = orig_xx - max; } } @@ -142,10 +203,14 @@ __forceinline__ __device__ void locate_xx_se_r(FPTYPE& xx, const FPTYPE& upper, const FPTYPE& max, const FPTYPE& stride0, - const FPTYPE& stride1) { + const FPTYPE& stride1, + FPTYPE& extrapolate_delta) { + const FPTYPE orig_xx = xx; + extrapolate_delta = (FPTYPE)0.; if (xx < lower) { table_idx = 0; xx = (FPTYPE)0.; + extrapolate_delta = orig_xx - lower; } else if (xx < upper) { table_idx = (int)((xx - lower) / stride0); xx -= (table_idx * stride0 + lower); @@ -154,9 +219,10 @@ __forceinline__ __device__ void locate_xx_se_r(FPTYPE& xx, table_idx = first_stride + (int)((xx - upper) / stride1); xx -= ((table_idx - first_stride) * stride1 + upper); } else { - table_idx = - int((upper - lower) / stride0) + (int)((max - upper) / stride1) - 1; - xx = (FPTYPE)0.; + int first_stride = int((upper - lower) / stride0); + table_idx = locate_high_tail_xx(lower, upper, max, stride0, stride1); + xx = max - ((table_idx - first_stride) * stride1 + upper); + extrapolate_delta = orig_xx - max; } } @@ -175,6 +241,32 @@ __forceinline__ __device__ void load_polynomial_params( var[5] = table[table_idx * last_layer_size * 6 + idx * 6 + 5]; } +template +__forceinline__ __device__ FPTYPE polynomial5(const FPTYPE var[6], + const FPTYPE& xx) { + return var[0] + + (var[1] + + (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * + xx; +} + +template +__forceinline__ __device__ FPTYPE polynomial5_grad(const FPTYPE var[6], + const FPTYPE& xx) { + return var[1] + ((FPTYPE)2. * var[2] + + ((FPTYPE)3. * var[3] + + ((FPTYPE)4. * var[4] + (FPTYPE)5. * var[5] * xx) * xx) * + xx) * + xx; +} + +template +__forceinline__ __device__ FPTYPE extrapolated_polynomial5( + const FPTYPE var[6], const FPTYPE& xx, const FPTYPE& extrapolate_delta) { + const FPTYPE grad = polynomial5_grad(var, xx); + return polynomial5(var, xx) + grad * extrapolate_delta; +} + template __forceinline__ __device__ FPTYPE dot(FPTYPE ll[4], FPTYPE rr[4]) { return ll[0] * rr[0] + ll[1] * rr[1] + ll[2] * rr[2] + ll[3] * rr[3]; @@ -239,15 +331,14 @@ __global__ void tabulate_fusion_se_a_fifth_order_polynomial( breakpoint = ii; } int table_idx = 0; - locate_xx_se_a(xx, table_idx, lower, upper, max, stride0, stride1); + FPTYPE extrapolate_delta = (FPTYPE)0.; + locate_xx_se_a(xx, table_idx, lower, upper, max, stride0, stride1, + extrapolate_delta); if (table_idx != mark_table_idx) { load_polynomial_params(var, table, table_idx, thread_idx, last_layer_size); } - FPTYPE res = - var[0] + - (var[1] + (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * - xx; + FPTYPE res = extrapolated_polynomial5(var, xx, extrapolate_delta); if (enable_se_atten) { FPTYPE t = two_embed[block_idx * nnei * last_layer_size + ii * last_layer_size + thread_idx]; @@ -334,16 +425,15 @@ __global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial( em[block_idx * nnei * MTILE + ii * 4 + 3]}; FPTYPE Csub = (FPTYPE)0.; FPTYPE sum[MTILE] = {(FPTYPE)0.}; - locate_xx_se_a(xx, table_idx, lower, upper, max, stride0, stride1); + FPTYPE extrapolate_delta = (FPTYPE)0.; + locate_xx_se_a(xx, table_idx, lower, upper, max, stride0, stride1, + extrapolate_delta); FPTYPE var[6]; for (int jj = lane_idx; jj < last_layer_size; jj += WARP_SIZE) { load_polynomial_params(var, table, table_idx, jj, last_layer_size); - FPTYPE res = - var[0] + - (var[1] + - (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * - xx; + FPTYPE res_grad = polynomial5_grad(var, xx); + FPTYPE res = polynomial5(var, xx) + res_grad * extrapolate_delta; FPTYPE oldres = res; FPTYPE t; if (enable_se_atten) { @@ -360,14 +450,8 @@ __global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial( res += reg_em[1] * iteratorA[1 * last_layer_size + jj]; res += reg_em[2] * iteratorA[2 * last_layer_size + jj]; res += reg_em[3] * iteratorA[3 * last_layer_size + jj]; - Csub += - (nnei - breakpoint) * - (var[1] + ((FPTYPE)2. * var[2] + - ((FPTYPE)3. * var[3] + - ((FPTYPE)4. * var[4] + (FPTYPE)5. * var[5] * xx) * xx) * - xx) * - xx) * - (enable_se_atten ? res * t + res : res); + Csub += (nnei - breakpoint) * res_grad * + (enable_se_atten ? res * t + res : res); if (enable_se_atten) { // from ii to ii + (nnei - breakpoint) for (int ii2 = ii; ii2 < ii + nnei - breakpoint; ii2++) { @@ -436,22 +520,16 @@ __global__ void tabulate_fusion_se_a_grad_grad_fifth_order_polynomial( breakpoint = ii; } int table_idx = 0; - locate_xx_se_a(xx, table_idx, lower, upper, max, stride0, stride1); + FPTYPE extrapolate_delta = (FPTYPE)0.; + locate_xx_se_a(xx, table_idx, lower, upper, max, stride0, stride1, + extrapolate_delta); if (table_idx != mark_table_idx) { load_polynomial_params(var, table, table_idx, thread_idx, last_layer_size); } - FPTYPE res = - var[0] + - (var[1] + (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * - xx; - FPTYPE res_grad = - var[1] + ((FPTYPE)2. * var[2] + - ((FPTYPE)3. * var[3] + - ((FPTYPE)4. * var[4] + (FPTYPE)5. * var[5] * xx) * xx) * - xx) * - xx; + FPTYPE res_grad = polynomial5_grad(var, xx); + FPTYPE res = polynomial5(var, xx) + res_grad * extrapolate_delta; FPTYPE two_grad = 0.; if (enable_se_atten) { FPTYPE t = two_embed[block_idx * nnei * last_layer_size + @@ -527,16 +605,14 @@ __global__ void tabulate_fusion_se_t_fifth_order_polynomial( FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj]; FPTYPE tmp = xx; int table_idx = 0; - locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1); + FPTYPE extrapolate_delta = (FPTYPE)0.; + locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1, + extrapolate_delta); if (table_idx != mark_table_idx) { load_polynomial_params(var, table, table_idx, thread_idx, last_layer_size); } - FPTYPE res = - var[0] + - (var[1] + - (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * - xx; + FPTYPE res = extrapolated_polynomial5(var, xx, extrapolate_delta); sum += tmp * res; mark_table_idx = table_idx; @@ -577,26 +653,19 @@ __global__ void tabulate_fusion_se_t_grad_fifth_order_polynomial( FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj]; FPTYPE tmp = xx; int table_idx = 0; - locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1); + FPTYPE extrapolate_delta = (FPTYPE)0.; + locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1, + extrapolate_delta); FPTYPE sum = (FPTYPE)0.; FPTYPE Csub = (FPTYPE)0.; for (int kk = lane_idx; kk < last_layer_size; kk += WARP_SIZE) { FPTYPE var[6]; load_polynomial_params(var, table, table_idx, kk, last_layer_size); - FPTYPE res = - var[0] + - (var[1] + - (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * - xx; + FPTYPE res_grad = polynomial5_grad(var, xx); + FPTYPE res = polynomial5(var, xx) + res_grad * extrapolate_delta; sum += iteratorA[kk] * res; - Csub += - iteratorA[kk] * tmp * - (var[1] + ((FPTYPE)2. * var[2] + - ((FPTYPE)3. * var[3] + - ((FPTYPE)4. * var[4] + (FPTYPE)5. * var[5] * xx) * xx) * - xx) * - xx); + Csub += iteratorA[kk] * tmp * res_grad; } GpuSyncThreads(); warp_reduce(sum); @@ -640,20 +709,15 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial( FPTYPE var[6]; int table_idx = 0; - locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1); + FPTYPE extrapolate_delta = (FPTYPE)0.; + locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1, + extrapolate_delta); if (table_idx != mark_table_idx) { load_polynomial_params(var, table, table_idx, thread_idx, last_layer_size); } - FPTYPE res = - var[0] + - (var[1] + - (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * - xx; - FPTYPE res_grad = - var[1] + (2 * var[2] + - (3 * var[3] + (4 * var[4] + 5 * var[5] * xx) * xx) * xx) * - xx; + FPTYPE res_grad = polynomial5_grad(var, xx); + FPTYPE res = polynomial5(var, xx) + res_grad * extrapolate_delta; sum += (tmp * res_grad * dz_xx + dz_em * res); mark_table_idx = table_idx; @@ -695,19 +759,16 @@ __global__ void tabulate_fusion_se_t_tebd_fifth_order_polynomial( // Determine the table index based on the value of xx. int table_idx = 0; + FPTYPE extrapolate_delta = (FPTYPE)0.; locate_xx_se_t_tebd(xx, table_idx, lower, upper, -max, max, stride0, - stride1); + stride1, extrapolate_delta); // Serially loop through the 'last_layer_size' dimension to calculate all // features. for (int idx = 0; idx < last_layer_size; idx++) { FPTYPE var[6]; load_polynomial_params(var, table, table_idx, idx, last_layer_size); - FPTYPE res = - var[0] + - (var[1] + - (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * - xx; + FPTYPE res = extrapolated_polynomial5(var, xx, extrapolate_delta); // Calculate the unique 1D output index for the 4D tensor (block_idx, ii, // jj, idx). const int_64 out_idx = @@ -750,8 +811,9 @@ __global__ void tabulate_fusion_se_t_tebd_grad_fifth_order_polynomial( // Determine the table index based on the value of xx. FPTYPE xx = em_x[i]; int table_idx = 0; + FPTYPE extrapolate_delta = (FPTYPE)0.; locate_xx_se_t_tebd(xx, table_idx, lower, upper, -max, max, stride0, - stride1); + stride1, extrapolate_delta); // Accumulate the gradient contributions from all features. FPTYPE grad_sum = 0.0; @@ -760,12 +822,7 @@ __global__ void tabulate_fusion_se_t_tebd_grad_fifth_order_polynomial( load_polynomial_params(var, table, table_idx, idx, last_layer_size); // Calculate the derivative of the polynomial with respect to xx. - FPTYPE dres_dxx = - var[1] + ((FPTYPE)2. * var[2] + - ((FPTYPE)3. * var[3] + - ((FPTYPE)4. * var[4] + (FPTYPE)5. * var[5] * xx) * xx) * - xx) * - xx; + FPTYPE dres_dxx = polynomial5_grad(var, xx); // Read the incoming gradient from the previous layer. const int_64 dy_idx = @@ -819,8 +876,9 @@ __global__ void tabulate_fusion_se_t_tebd_grad_grad_fifth_order_polynomial( // Determine the table index based on the value of xx. int table_idx = 0; + FPTYPE extrapolate_delta = (FPTYPE)0.; locate_xx_se_t_tebd(xx, table_idx, lower, upper, -max, max, stride0, - stride1); + stride1, extrapolate_delta); // Serially loop through the 'last_layer_size' dimension. for (int idx = 0; idx < last_layer_size; idx++) { @@ -828,12 +886,7 @@ __global__ void tabulate_fusion_se_t_tebd_grad_grad_fifth_order_polynomial( load_polynomial_params(var, table, table_idx, idx, last_layer_size); // Calculate the derivative of the polynomial with respect to xx. - FPTYPE dres_dxx = - var[1] + ((FPTYPE)2. * var[2] + - ((FPTYPE)3. * var[3] + - ((FPTYPE)4. * var[4] + (FPTYPE)5. * var[5] * xx) * xx) * - xx) * - xx; + FPTYPE dres_dxx = polynomial5_grad(var, xx); // Apply the chain rule: dz/dy_idx = (dz/dxx) * (dxx/dy_idx) // which simplifies to dz_dy_dem_x_val * dres_dxx @@ -870,16 +923,15 @@ __global__ void tabulate_fusion_se_r_fifth_order_polynomial( for (int ii = 0; ii < nnei; ii++) { FPTYPE xx = em[block_idx * nnei + ii]; int table_idx = 0; - locate_xx_se_r(xx, table_idx, lower, upper, max, stride0, stride1); + FPTYPE extrapolate_delta = (FPTYPE)0.; + locate_xx_se_r(xx, table_idx, lower, upper, max, stride0, stride1, + extrapolate_delta); if (table_idx != mark_table_idx) { load_polynomial_params(var, table, table_idx, thread_idx, last_layer_size); } out[block_idx * nnei * last_layer_size + ii * last_layer_size + - thread_idx] = - var[0] + - (var[1] + (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * - xx; + thread_idx] = extrapolated_polynomial5(var, xx, extrapolate_delta); mark_table_idx = table_idx; } } @@ -907,17 +959,15 @@ __global__ void tabulate_fusion_se_r_grad_fifth_order_polynomial( int table_idx = 0; FPTYPE Csub = (FPTYPE)0.; - locate_xx_se_r(xx, table_idx, lower, upper, max, stride0, stride1); + FPTYPE extrapolate_delta = (FPTYPE)0.; + locate_xx_se_r(xx, table_idx, lower, upper, max, stride0, stride1, + extrapolate_delta); FPTYPE var[6]; for (int jj = lane_idx; jj < last_layer_size; jj += WARP_SIZE) { load_polynomial_params(var, table, table_idx, jj, last_layer_size); Csub += - (var[1] + ((FPTYPE)2. * var[2] + - ((FPTYPE)3. * var[3] + - ((FPTYPE)4. * var[4] + (FPTYPE)5. * var[5] * xx) * xx) * - xx) * - xx) * + polynomial5_grad(var, xx) * dy[block_idx * nnei * last_layer_size + ii * last_layer_size + jj]; } GpuSyncThreads(); @@ -954,17 +1004,14 @@ __global__ void tabulate_fusion_se_r_grad_grad_fifth_order_polynomial( for (int ii = 0; ii < nnei; ii++) { FPTYPE xx = em[block_idx * nnei + ii]; int table_idx = 0; - locate_xx_se_r(xx, table_idx, lower, upper, max, stride0, stride1); + FPTYPE extrapolate_delta = (FPTYPE)0.; + locate_xx_se_r(xx, table_idx, lower, upper, max, stride0, stride1, + extrapolate_delta); if (table_idx != mark_table_idx) { load_polynomial_params(var, table, table_idx, thread_idx, last_layer_size); } - FPTYPE res_grad = - var[1] + ((FPTYPE)2. * var[2] + - ((FPTYPE)3. * var[3] + - ((FPTYPE)4. * var[4] + (FPTYPE)5. * var[5] * xx) * xx) * - xx) * - xx; + FPTYPE res_grad = polynomial5_grad(var, xx); mark_table_idx = table_idx; dz_dy[block_idx * nnei * last_layer_size + ii * last_layer_size + thread_idx] = dz_dy_dem[block_idx * nnei + ii] * res_grad; diff --git a/source/lib/src/tabulate.cc b/source/lib/src/tabulate.cc index e3b1b770ca..19fcae13ba 100644 --- a/source/lib/src/tabulate.cc +++ b/source/lib/src/tabulate.cc @@ -4,6 +4,7 @@ #include #include +#include #include #include /* @@ -16,6 +17,30 @@ xx: indicate the inputs value; table_idx: indicate the location of table info of input value xx; */ +template +inline int locate_high_tail_xx(const FPTYPE& lower, + const FPTYPE& upper, + const FPTYPE& max, + const FPTYPE& stride0, + const FPTYPE& stride1) { + const FPTYPE boundary_xx = std::nextafter(max, lower); + const int first_stride = int((upper - lower) / stride0); + return first_stride + int((boundary_xx - upper) / stride1); +} + +template +inline int locate_high_tail_xx_se_t(const FPTYPE& lower, + const FPTYPE& upper, + const FPTYPE& min, + const FPTYPE& max, + const FPTYPE& stride0, + const FPTYPE& stride1) { + const FPTYPE boundary_xx = std::nextafter(max, min); + const int first_stride = + int((lower - min) / stride1) + int((upper - lower) / stride0); + return first_stride + int((boundary_xx - upper) / stride1); +} + template inline void locate_xx(const FPTYPE& lower, const FPTYPE& upper, @@ -23,10 +48,14 @@ inline void locate_xx(const FPTYPE& lower, const FPTYPE& stride0, const FPTYPE& stride1, FPTYPE& xx, - int& table_idx) { + int& table_idx, + FPTYPE& extrapolate_delta) { + const FPTYPE orig_xx = xx; + extrapolate_delta = (FPTYPE)0.; if (xx < lower) { table_idx = 0; xx = (FPTYPE)0.; + extrapolate_delta = orig_xx - lower; } else if (xx < upper) { table_idx = (int)((xx - lower) / stride0); xx -= (table_idx * stride0 + lower); @@ -35,9 +64,10 @@ inline void locate_xx(const FPTYPE& lower, table_idx = first_stride + (int)((xx - upper) / stride1); xx -= ((table_idx - first_stride) * stride1 + upper); } else { - table_idx = - int((upper - lower) / stride0) + (int)((max - upper) / stride1) - 1; - xx = (FPTYPE)0.; + int first_stride = int((upper - lower) / stride0); + table_idx = locate_high_tail_xx(lower, upper, max, stride0, stride1); + xx = max - ((table_idx - first_stride) * stride1 + upper); + extrapolate_delta = orig_xx - max; } } @@ -49,10 +79,14 @@ inline void locate_xx_se_t(const FPTYPE& lower, const FPTYPE& stride0, const FPTYPE& stride1, FPTYPE& xx, - int& table_idx) { + int& table_idx, + FPTYPE& extrapolate_delta) { + const FPTYPE orig_xx = xx; + extrapolate_delta = (FPTYPE)0.; if (xx < min) { table_idx = 0; xx = (FPTYPE)0.; + extrapolate_delta = orig_xx - min; } else if (xx < lower) { table_idx = (int)((xx - min) / stride1); xx -= (table_idx * stride1 + min); @@ -66,9 +100,12 @@ inline void locate_xx_se_t(const FPTYPE& lower, table_idx = first_stride + (int)((xx - upper) / stride1); xx -= ((table_idx - first_stride) * stride1 + upper); } else { - table_idx = int((lower - min) / stride1) + int((upper - lower) / stride0) + - (int)((max - upper) / stride1) - 1; - xx = (FPTYPE)0.; + int first_stride = + int((lower - min) / stride1) + int((upper - lower) / stride0); + table_idx = + locate_high_tail_xx_se_t(lower, upper, min, max, stride0, stride1); + xx = max - ((table_idx - first_stride) * stride1 + upper); + extrapolate_delta = orig_xx - max; } } @@ -77,6 +114,44 @@ inline FPTYPE dot(FPTYPE a[4], FPTYPE b[4]) { return a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + a[3] * b[3]; } +template +inline FPTYPE polynomial5(const FPTYPE& a0, + const FPTYPE& a1, + const FPTYPE& a2, + const FPTYPE& a3, + const FPTYPE& a4, + const FPTYPE& a5, + const FPTYPE& xx) { + return a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx; +} + +template +inline FPTYPE polynomial5_grad(const FPTYPE& a1, + const FPTYPE& a2, + const FPTYPE& a3, + const FPTYPE& a4, + const FPTYPE& a5, + const FPTYPE& xx) { + return a1 + + ((FPTYPE)2. * a2 + + ((FPTYPE)3. * a3 + ((FPTYPE)4. * a4 + (FPTYPE)5. * a5 * xx) * xx) * + xx) * + xx; +} + +template +inline FPTYPE extrapolated_polynomial5(const FPTYPE& a0, + const FPTYPE& a1, + const FPTYPE& a2, + const FPTYPE& a3, + const FPTYPE& a4, + const FPTYPE& a5, + const FPTYPE& xx, + const FPTYPE& extrapolate_delta) { + const FPTYPE grad = polynomial5_grad(a1, a2, a3, a4, a5, xx); + return polynomial5(a0, a1, a2, a3, a4, a5, xx) + grad * extrapolate_delta; +} + template void deepmd::tabulate_fusion_se_a_cpu(FPTYPE* out, const FPTYPE* table, @@ -112,7 +187,9 @@ void deepmd::tabulate_fusion_se_a_cpu(FPTYPE* out, unloop = true; } int table_idx = 0; - locate_xx(lower, upper, _max, stride0, stride1, xx, table_idx); + FPTYPE extrapolate_delta = (FPTYPE)0.; + locate_xx(lower, upper, _max, stride0, stride1, xx, table_idx, + extrapolate_delta); for (int kk = 0; kk < last_layer_size; kk++) { FPTYPE a0 = table[table_idx * last_layer_size * 6 + 6 * kk + 0]; FPTYPE a1 = table[table_idx * last_layer_size * 6 + 6 * kk + 1]; @@ -120,8 +197,8 @@ void deepmd::tabulate_fusion_se_a_cpu(FPTYPE* out, FPTYPE a3 = table[table_idx * last_layer_size * 6 + 6 * kk + 3]; FPTYPE a4 = table[table_idx * last_layer_size * 6 + 6 * kk + 4]; FPTYPE a5 = table[table_idx * last_layer_size * 6 + 6 * kk + 5]; - FPTYPE var = - a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx; + FPTYPE var = extrapolated_polynomial5(a0, a1, a2, a3, a4, a5, xx, + extrapolate_delta); if (enable_se_atten) { FPTYPE t = two_embed[ii * nnei * last_layer_size + jj * last_layer_size + kk]; @@ -199,7 +276,9 @@ void deepmd::tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x, unloop = true; } int table_idx = 0; - locate_xx(lower, upper, _max, stride0, stride1, xx, table_idx); + FPTYPE extrapolate_delta = (FPTYPE)0.; + locate_xx(lower, upper, _max, stride0, stride1, xx, table_idx, + extrapolate_delta); FPTYPE grad = (FPTYPE)0.0; for (int kk = 0; kk < last_layer_size; kk++) { rr[0] = dy[ii * last_layer_size * 4 + 0 * last_layer_size + kk]; @@ -212,10 +291,9 @@ void deepmd::tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x, FPTYPE a3 = table[table_idx * last_layer_size * 6 + 6 * kk + 3]; FPTYPE a4 = table[table_idx * last_layer_size * 6 + 6 * kk + 4]; FPTYPE a5 = table[table_idx * last_layer_size * 6 + 6 * kk + 5]; + FPTYPE g = polynomial5_grad(a1, a2, a3, a4, a5, xx); FPTYPE res = - a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx; - FPTYPE g = - (a1 + (2 * a2 + (3 * a3 + (4 * a4 + 5 * a5 * xx) * xx) * xx) * xx); + polynomial5(a0, a1, a2, a3, a4, a5, xx) + g * extrapolate_delta; FPTYPE resold = res; if (enable_se_atten) { FPTYPE t = two_embed[ii * nnei * last_layer_size + @@ -302,7 +380,9 @@ void deepmd::tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy, unloop = true; } int table_idx = 0; - locate_xx(lower, upper, _max, stride0, stride1, xx, table_idx); + FPTYPE extrapolate_delta = (FPTYPE)0.; + locate_xx(lower, upper, _max, stride0, stride1, xx, table_idx, + extrapolate_delta); for (int kk = 0; kk < last_layer_size; kk++) { FPTYPE a0 = table[table_idx * last_layer_size * 6 + 6 * kk + 0]; FPTYPE a1 = table[table_idx * last_layer_size * 6 + 6 * kk + 1]; @@ -310,14 +390,9 @@ void deepmd::tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy, FPTYPE a3 = table[table_idx * last_layer_size * 6 + 6 * kk + 3]; FPTYPE a4 = table[table_idx * last_layer_size * 6 + 6 * kk + 4]; FPTYPE a5 = table[table_idx * last_layer_size * 6 + 6 * kk + 5]; - FPTYPE var = - a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx; - FPTYPE var_grad = - a1 + - ((FPTYPE)2. * a2 + - ((FPTYPE)3. * a3 + ((FPTYPE)4. * a4 + (FPTYPE)5. * a5 * xx) * xx) * - xx) * - xx; + FPTYPE var_grad = polynomial5_grad(a1, a2, a3, a4, a5, xx); + FPTYPE var = polynomial5(a0, a1, a2, a3, a4, a5, xx) + + var_grad * extrapolate_delta; FPTYPE two_grad = 0.; if (enable_se_atten) { FPTYPE t = two_embed[ii * nnei * last_layer_size + @@ -408,8 +483,9 @@ void deepmd::tabulate_fusion_se_t_cpu(FPTYPE* out, FPTYPE xx = em_x[ii * nnei_i * nnei_j + jj * nnei_j + kk]; FPTYPE ll = xx; int table_idx = 0; + FPTYPE extrapolate_delta = (FPTYPE)0.; locate_xx_se_t(lower, upper, -_max, _max, stride0, stride1, xx, - table_idx); + table_idx, extrapolate_delta); for (int mm = 0; mm < last_layer_size; mm++) { FPTYPE a0 = table[table_idx * last_layer_size * 6 + 6 * mm + 0]; FPTYPE a1 = table[table_idx * last_layer_size * 6 + 6 * mm + 1]; @@ -417,8 +493,8 @@ void deepmd::tabulate_fusion_se_t_cpu(FPTYPE* out, FPTYPE a3 = table[table_idx * last_layer_size * 6 + 6 * mm + 3]; FPTYPE a4 = table[table_idx * last_layer_size * 6 + 6 * mm + 4]; FPTYPE a5 = table[table_idx * last_layer_size * 6 + 6 * mm + 5]; - FPTYPE var = - a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx; + FPTYPE var = extrapolated_polynomial5(a0, a1, a2, a3, a4, a5, xx, + extrapolate_delta); out[ii * last_layer_size + mm] += var * ll; } } @@ -457,8 +533,9 @@ void deepmd::tabulate_fusion_se_t_grad_cpu(FPTYPE* dy_dem_x, FPTYPE xx = em_x[ii * nnei_i * nnei_j + jj * nnei_j + kk]; ll = xx; int table_idx = 0; + FPTYPE extrapolate_delta = (FPTYPE)0.; locate_xx_se_t(lower, upper, -_max, _max, stride0, stride1, xx, - table_idx); + table_idx, extrapolate_delta); FPTYPE grad = (FPTYPE)0.0; for (int mm = 0; mm < last_layer_size; mm++) { rr = dy[ii * last_layer_size + mm]; @@ -468,15 +545,11 @@ void deepmd::tabulate_fusion_se_t_grad_cpu(FPTYPE* dy_dem_x, FPTYPE a3 = table[table_idx * last_layer_size * 6 + 6 * mm + 3]; FPTYPE a4 = table[table_idx * last_layer_size * 6 + 6 * mm + 4]; FPTYPE a5 = table[table_idx * last_layer_size * 6 + 6 * mm + 5]; - FPTYPE res = - a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx; + FPTYPE res_grad = polynomial5_grad(a1, a2, a3, a4, a5, xx); + FPTYPE res = polynomial5(a0, a1, a2, a3, a4, a5, xx) + + res_grad * extrapolate_delta; - grad += (a1 + ((FPTYPE)2. * a2 + - ((FPTYPE)3. * a3 + - ((FPTYPE)4. * a4 + (FPTYPE)5. * a5 * xx) * xx) * - xx) * - xx) * - ll * rr; + grad += res_grad * ll * rr; dy_dem[ii * nnei_i * nnei_j + jj * nnei_j + kk] += res * rr; } dy_dem_x[ii * nnei_i * nnei_j + jj * nnei_j + kk] = grad; @@ -515,8 +588,9 @@ void deepmd::tabulate_fusion_se_t_grad_grad_cpu(FPTYPE* dz_dy, FPTYPE dz_xx = dz_dy_dem_x[ii * nnei_i * nnei_j + jj * nnei_j + kk]; int table_idx = 0; + FPTYPE extrapolate_delta = (FPTYPE)0.; locate_xx_se_t(lower, upper, -_max, _max, stride0, stride1, xx, - table_idx); + table_idx, extrapolate_delta); for (int mm = 0; mm < last_layer_size; mm++) { FPTYPE a0 = table[table_idx * last_layer_size * 6 + 6 * mm + 0]; FPTYPE a1 = table[table_idx * last_layer_size * 6 + 6 * mm + 1]; @@ -524,14 +598,9 @@ void deepmd::tabulate_fusion_se_t_grad_grad_cpu(FPTYPE* dz_dy, FPTYPE a3 = table[table_idx * last_layer_size * 6 + 6 * mm + 3]; FPTYPE a4 = table[table_idx * last_layer_size * 6 + 6 * mm + 4]; FPTYPE a5 = table[table_idx * last_layer_size * 6 + 6 * mm + 5]; - FPTYPE var = - a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx; - FPTYPE var_grad = - a1 + ((FPTYPE)2. * a2 + - ((FPTYPE)3. * a3 + - ((FPTYPE)4. * a4 + (FPTYPE)5. * a5 * xx) * xx) * - xx) * - xx; + FPTYPE var_grad = polynomial5_grad(a1, a2, a3, a4, a5, xx); + FPTYPE var = polynomial5(a0, a1, a2, a3, a4, a5, xx) + + var_grad * extrapolate_delta; dz_dy[ii * last_layer_size + mm] += var * dz_em + dz_xx * var_grad * tmp; @@ -564,8 +633,9 @@ void deepmd::tabulate_fusion_se_t_tebd_cpu(FPTYPE* out, for (int kk = 0; kk < nnei_j; kk++) { FPTYPE xx = em_x[ii * nnei_i * nnei_j + jj * nnei_j + kk]; int table_idx = 0; + FPTYPE extrapolate_delta = (FPTYPE)0.; locate_xx_se_t(lower, upper, -_max, _max, stride0, stride1, xx, - table_idx); + table_idx, extrapolate_delta); // For SE_TEBD, we preserve the full nt_i x nt_j x ng structure // instead of reducing it like SE_T does @@ -577,8 +647,8 @@ void deepmd::tabulate_fusion_se_t_tebd_cpu(FPTYPE* out, FPTYPE a4 = table[table_idx * last_layer_size * 6 + 6 * mm + 4]; FPTYPE a5 = table[table_idx * last_layer_size * 6 + 6 * mm + 5]; - FPTYPE res = a0 + a1 * xx + a2 * xx * xx + a3 * xx * xx * xx + - a4 * xx * xx * xx * xx + a5 * xx * xx * xx * xx * xx; + FPTYPE res = extrapolated_polynomial5(a0, a1, a2, a3, a4, a5, xx, + extrapolate_delta); // Store result preserving the nt_i x nt_j structure out[ii * nnei_i * nnei_j * last_layer_size + @@ -613,8 +683,9 @@ void deepmd::tabulate_fusion_se_t_tebd_grad_cpu(FPTYPE* dy_dem_x, for (int kk = 0; kk < nnei_j; kk++) { FPTYPE xx = em_x[ii * nnei_i * nnei_j + jj * nnei_j + kk]; int table_idx = 0; + FPTYPE extrapolate_delta = (FPTYPE)0.; locate_xx_se_t(lower, upper, -_max, _max, stride0, stride1, xx, - table_idx); + table_idx, extrapolate_delta); FPTYPE grad_sum = 0.0; for (int mm = 0; mm < last_layer_size; mm++) { @@ -624,9 +695,7 @@ void deepmd::tabulate_fusion_se_t_tebd_grad_cpu(FPTYPE* dy_dem_x, FPTYPE a4 = table[table_idx * last_layer_size * 6 + 6 * mm + 4]; FPTYPE a5 = table[table_idx * last_layer_size * 6 + 6 * mm + 5]; - FPTYPE dres_dxx = a1 + 2.0 * a2 * xx + 3.0 * a3 * xx * xx + - 4.0 * a4 * xx * xx * xx + - 5.0 * a5 * xx * xx * xx * xx; + FPTYPE dres_dxx = polynomial5_grad(a1, a2, a3, a4, a5, xx); FPTYPE dy_val = dy[ii * nnei_i * nnei_j * last_layer_size + @@ -665,8 +734,9 @@ void deepmd::tabulate_fusion_se_t_tebd_grad_grad_cpu( for (int kk = 0; kk < nnei_j; kk++) { FPTYPE xx = em_x[ii * nnei_i * nnei_j + jj * nnei_j + kk]; int table_idx = 0; + FPTYPE extrapolate_delta = (FPTYPE)0.; locate_xx_se_t(lower, upper, -_max, _max, stride0, stride1, xx, - table_idx); + table_idx, extrapolate_delta); FPTYPE dz_dy_dem_x_val = dz_dy_dem_x[ii * nnei_i * nnei_j + jj * nnei_j + kk]; @@ -678,9 +748,7 @@ void deepmd::tabulate_fusion_se_t_tebd_grad_grad_cpu( FPTYPE a4 = table[table_idx * last_layer_size * 6 + 6 * mm + 4]; FPTYPE a5 = table[table_idx * last_layer_size * 6 + 6 * mm + 5]; - FPTYPE dres_dxx = a1 + 2.0 * a2 * xx + 3.0 * a3 * xx * xx + - 4.0 * a4 * xx * xx * xx + - 5.0 * a5 * xx * xx * xx * xx; + FPTYPE dres_dxx = polynomial5_grad(a1, a2, a3, a4, a5, xx); dz_dy[ii * nnei_i * nnei_j * last_layer_size + jj * nnei_j * last_layer_size + kk * last_layer_size + mm] = @@ -712,7 +780,9 @@ void deepmd::tabulate_fusion_se_r_cpu(FPTYPE* out, for (int jj = 0; jj < nnei; jj++) { FPTYPE xx = em[ii * nnei + jj]; int table_idx = 0; - locate_xx(lower, upper, _max, stride0, stride1, xx, table_idx); + FPTYPE extrapolate_delta = (FPTYPE)0.; + locate_xx(lower, upper, _max, stride0, stride1, xx, table_idx, + extrapolate_delta); for (int kk = 0; kk < last_layer_size; kk++) { FPTYPE a0 = table[table_idx * last_layer_size * 6 + 6 * kk + 0]; FPTYPE a1 = table[table_idx * last_layer_size * 6 + 6 * kk + 1]; @@ -721,7 +791,8 @@ void deepmd::tabulate_fusion_se_r_cpu(FPTYPE* out, FPTYPE a4 = table[table_idx * last_layer_size * 6 + 6 * kk + 4]; FPTYPE a5 = table[table_idx * last_layer_size * 6 + 6 * kk + 5]; out[ii * last_layer_size * nnei + jj * last_layer_size + kk] = - a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx; + extrapolated_polynomial5(a0, a1, a2, a3, a4, a5, xx, + extrapolate_delta); } } } @@ -750,7 +821,9 @@ void deepmd::tabulate_fusion_se_r_grad_cpu(FPTYPE* dy_dem, // construct the dy/dx FPTYPE xx = em[ii * nnei + jj]; int table_idx = 0; - locate_xx(lower, upper, _max, stride0, stride1, xx, table_idx); + FPTYPE extrapolate_delta = (FPTYPE)0.; + locate_xx(lower, upper, _max, stride0, stride1, xx, table_idx, + extrapolate_delta); FPTYPE grad = (FPTYPE)0.0; for (int kk = 0; kk < last_layer_size; kk++) { FPTYPE a0 = table[table_idx * last_layer_size * 6 + 6 * kk + 0]; @@ -759,11 +832,7 @@ void deepmd::tabulate_fusion_se_r_grad_cpu(FPTYPE* dy_dem, FPTYPE a3 = table[table_idx * last_layer_size * 6 + 6 * kk + 3]; FPTYPE a4 = table[table_idx * last_layer_size * 6 + 6 * kk + 4]; FPTYPE a5 = table[table_idx * last_layer_size * 6 + 6 * kk + 5]; - grad += (a1 + ((FPTYPE)2. * a2 + - ((FPTYPE)3. * a3 + - ((FPTYPE)4. * a4 + (FPTYPE)5. * a5 * xx) * xx) * - xx) * - xx) * + grad += polynomial5_grad(a1, a2, a3, a4, a5, xx) * dy[ii * last_layer_size * nnei + jj * last_layer_size + kk]; } dy_dem[ii * nnei + jj] = grad; @@ -793,7 +862,9 @@ void deepmd::tabulate_fusion_se_r_grad_grad_cpu(FPTYPE* dz_dy, for (int jj = 0; jj < nnei; jj++) { FPTYPE xx = em[ii * nnei + jj]; int table_idx = 0; - locate_xx(lower, upper, _max, stride0, stride1, xx, table_idx); + FPTYPE extrapolate_delta = (FPTYPE)0.; + locate_xx(lower, upper, _max, stride0, stride1, xx, table_idx, + extrapolate_delta); for (int kk = 0; kk < last_layer_size; kk++) { FPTYPE a0 = table[table_idx * last_layer_size * 6 + 6 * kk + 0]; FPTYPE a1 = table[table_idx * last_layer_size * 6 + 6 * kk + 1]; @@ -801,12 +872,7 @@ void deepmd::tabulate_fusion_se_r_grad_grad_cpu(FPTYPE* dz_dy, FPTYPE a3 = table[table_idx * last_layer_size * 6 + 6 * kk + 3]; FPTYPE a4 = table[table_idx * last_layer_size * 6 + 6 * kk + 4]; FPTYPE a5 = table[table_idx * last_layer_size * 6 + 6 * kk + 5]; - FPTYPE var_grad = - a1 + - ((FPTYPE)2. * a2 + - ((FPTYPE)3. * a3 + ((FPTYPE)4. * a4 + (FPTYPE)5. * a5 * xx) * xx) * - xx) * - xx; + FPTYPE var_grad = polynomial5_grad(a1, a2, a3, a4, a5, xx); dz_dy[ii * last_layer_size * nnei + jj * last_layer_size + kk] = dz_dy_dem[ii * nnei + jj] * var_grad; } diff --git a/source/lib/tests/test_tabulate_extrapolate.cc b/source/lib/tests/test_tabulate_extrapolate.cc new file mode 100644 index 0000000000..a231a8476b --- /dev/null +++ b/source/lib/tests/test_tabulate_extrapolate.cc @@ -0,0 +1,959 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +#include + +#include +#include + +#include "device.h" +#include "tabulate.h" + +namespace { + +constexpr double kLower = -1.0; +constexpr double kUpper = 1.0; +constexpr double kMax = 2.0; +constexpr double kMin = -kMax; +constexpr double kStride0 = 1.0; +constexpr double kStride1 = 1.0; +constexpr int kLastLayerSize = 1; +constexpr double kFiniteDiffStep = 1e-6; +constexpr double kTolerance = 1e-10; +constexpr double kTwoEmbed = 0.75; + +const std::vector kTableInfo = {kLower, kUpper, kMax, + kStride0, kStride1, -1.0}; +const std::vector kOffGridTableInfo = {kLower, kUpper, 2.5, + kStride0, kStride1, -1.0}; + +const std::vector kTable = { + 10.0, 2.0, 3.0, 0.25, -0.5, 0.1, // + 20.0, -1.0, 0.5, -0.2, 0.25, -0.05, // + 30.0, 4.0, -2.0, 1.0, 0.1, 0.2, // + 40.0, -3.0, 1.0, 0.5, -0.25, 0.05, +}; + +struct LocatedXx { + int table_idx; + double offset; + double extrapolate_delta; +}; + +double poly5(const double* coeff, const double xx) { + return coeff[0] + + (coeff[1] + + (coeff[2] + (coeff[3] + (coeff[4] + coeff[5] * xx) * xx) * xx) * xx) * + xx; +} + +double poly5_grad(const double* coeff, const double xx) { + return coeff[1] + + (2.0 * coeff[2] + + (3.0 * coeff[3] + (4.0 * coeff[4] + 5.0 * coeff[5] * xx) * xx) * xx) * + xx; +} + +LocatedXx locate_se_a_or_r(const double xx, + const std::vector& table_info) { + const double lower = table_info[0]; + const double upper = table_info[1]; + const double max = table_info[2]; + const double stride0 = table_info[3]; + const double stride1 = table_info[4]; + if (xx < lower) { + return {0, 0.0, xx - lower}; + } + if (xx < upper) { + const int table_idx = static_cast((xx - lower) / stride0); + return {table_idx, xx - (table_idx * stride0 + lower), 0.0}; + } + if (xx < max) { + const int first_stride = static_cast((upper - lower) / stride0); + const int table_idx = + first_stride + static_cast((xx - upper) / stride1); + return {table_idx, xx - ((table_idx - first_stride) * stride1 + upper), + 0.0}; + } + + const int first_stride = static_cast((upper - lower) / stride0); + const double boundary_xx = std::nextafter(max, lower); + const int table_idx = + first_stride + static_cast((boundary_xx - upper) / stride1); + return {table_idx, max - ((table_idx - first_stride) * stride1 + upper), + xx - max}; +} + +LocatedXx locate_se_a_or_r(const double xx) { + return locate_se_a_or_r(xx, kTableInfo); +} + +LocatedXx locate_se_t(const double xx, const std::vector& table_info) { + const double lower = table_info[0]; + const double upper = table_info[1]; + const double max = table_info[2]; + const double min = -max; + const double stride0 = table_info[3]; + const double stride1 = table_info[4]; + if (xx < min) { + return {0, 0.0, xx - min}; + } + if (xx < lower) { + const int table_idx = static_cast((xx - min) / stride1); + return {table_idx, xx - (table_idx * stride1 + min), 0.0}; + } + if (xx < upper) { + const int first_stride = static_cast((lower - min) / stride1); + const int table_idx = + first_stride + static_cast((xx - lower) / stride0); + return {table_idx, xx - ((table_idx - first_stride) * stride0 + lower), + 0.0}; + } + if (xx < max) { + const int first_stride = static_cast((lower - min) / stride1) + + static_cast((upper - lower) / stride0); + const int table_idx = + first_stride + static_cast((xx - upper) / stride1); + return {table_idx, xx - ((table_idx - first_stride) * stride1 + upper), + 0.0}; + } + + const int first_stride = static_cast((lower - min) / stride1) + + static_cast((upper - lower) / stride0); + const double boundary_xx = std::nextafter(max, min); + const int table_idx = + first_stride + static_cast((boundary_xx - upper) / stride1); + return {table_idx, max - ((table_idx - first_stride) * stride1 + upper), + xx - max}; +} + +LocatedXx locate_se_t(const double xx) { return locate_se_t(xx, kTableInfo); } + +double expected_table_value(const LocatedXx& located) { + const double* coeff = &kTable[located.table_idx * 6]; + return poly5(coeff, located.offset) + + poly5_grad(coeff, located.offset) * located.extrapolate_delta; +} + +double expected_table_grad(const LocatedXx& located) { + return poly5_grad(&kTable[located.table_idx * 6], located.offset); +} + +double se_a_value(const double xx) { + std::vector out(4 * kLastLayerSize); + const std::vector em_x = {xx}; + const std::vector em = {1.0, 0.0, 0.0, 0.0}; + deepmd::tabulate_fusion_se_a_cpu( + out.data(), kTable.data(), kTableInfo.data(), em_x.data(), em.data(), + nullptr, 1, 1, kLastLayerSize); + return out[0]; +} + +double se_a_grad(const double xx) { + std::vector dy_dem_x(1); + std::vector dy_dem(4); + std::vector dy_dtwo(1); + const std::vector em_x = {xx}; + const std::vector em = {1.0, 0.0, 0.0, 0.0}; + const std::vector dy = {1.0, 0.0, 0.0, 0.0}; + deepmd::tabulate_fusion_se_a_grad_cpu( + dy_dem_x.data(), dy_dem.data(), dy_dtwo.data(), kTable.data(), + kTableInfo.data(), em_x.data(), em.data(), nullptr, dy.data(), 1, 1, + kLastLayerSize); + return dy_dem_x[0]; +} + +double se_a_value_with_two_embed(const double xx) { + std::vector out(4 * kLastLayerSize); + const std::vector em_x = {xx}; + const std::vector em = {1.0, 0.0, 0.0, 0.0}; + const std::vector two_embed = {kTwoEmbed}; + deepmd::tabulate_fusion_se_a_cpu( + out.data(), kTable.data(), kTableInfo.data(), em_x.data(), em.data(), + two_embed.data(), 1, 1, kLastLayerSize); + return out[0]; +} + +void se_a_grad_with_two_embed(const double xx, + double& dy_dem_x, + double& dy_dtwo) { + std::vector dy_dem_x_vec(1); + std::vector dy_dem(4); + std::vector dy_dtwo_vec(1); + const std::vector em_x = {xx}; + const std::vector em = {1.0, 0.0, 0.0, 0.0}; + const std::vector two_embed = {kTwoEmbed}; + const std::vector dy = {1.0, 0.0, 0.0, 0.0}; + deepmd::tabulate_fusion_se_a_grad_cpu( + dy_dem_x_vec.data(), dy_dem.data(), dy_dtwo_vec.data(), kTable.data(), + kTableInfo.data(), em_x.data(), em.data(), two_embed.data(), dy.data(), 1, + 1, kLastLayerSize); + dy_dem_x = dy_dem_x_vec[0]; + dy_dtwo = dy_dtwo_vec[0]; +} + +double se_a_dem_x_grad_with_two_embed(const double xx) { + double dy_dem_x = 0.0; + double dy_dtwo = 0.0; + se_a_grad_with_two_embed(xx, dy_dem_x, dy_dtwo); + return dy_dem_x; +} + +double se_r_value(const double xx, const std::vector& table_info) { + std::vector out(1); + const std::vector em = {xx}; + deepmd::tabulate_fusion_se_r_cpu(out.data(), kTable.data(), table_info.data(), + em.data(), 1, 1, kLastLayerSize); + return out[0]; +} + +double se_r_value(const double xx) { return se_r_value(xx, kTableInfo); } + +double se_r_grad(const double xx, const std::vector& table_info) { + std::vector dy_dem(1); + const std::vector em = {xx}; + const std::vector dy = {1.0}; + deepmd::tabulate_fusion_se_r_grad_cpu(dy_dem.data(), kTable.data(), + table_info.data(), em.data(), dy.data(), + 1, 1, kLastLayerSize); + return dy_dem[0]; +} + +double se_r_grad(const double xx) { return se_r_grad(xx, kTableInfo); } + +double se_t_value(const double xx) { + std::vector out(1); + const std::vector em_x = {xx}; + const std::vector em = {xx}; + deepmd::tabulate_fusion_se_t_cpu(out.data(), kTable.data(), kTableInfo.data(), + em_x.data(), em.data(), 1, 1, 1, + kLastLayerSize); + return out[0]; +} + +void se_t_grad(const double xx, double& dy_dem_x, double& dy_dem) { + std::vector dy_dem_x_vec(1); + std::vector dy_dem_vec(1); + const std::vector em_x = {xx}; + const std::vector em = {xx}; + const std::vector dy = {1.0}; + deepmd::tabulate_fusion_se_t_grad_cpu( + dy_dem_x_vec.data(), dy_dem_vec.data(), kTable.data(), kTableInfo.data(), + em_x.data(), em.data(), dy.data(), 1, 1, 1, kLastLayerSize); + dy_dem_x = dy_dem_x_vec[0]; + dy_dem = dy_dem_vec[0]; +} + +double se_t_tebd_value(const double xx) { + std::vector out(1); + const std::vector em_x = {xx}; + const std::vector em = {xx}; + deepmd::tabulate_fusion_se_t_tebd_cpu(out.data(), kTable.data(), + kTableInfo.data(), em_x.data(), + em.data(), 1, 1, 1, kLastLayerSize); + return out[0]; +} + +double se_t_tebd_grad(const double xx) { + std::vector dy_dem_x(1); + const std::vector em_x = {xx}; + const std::vector em = {xx}; + const std::vector dy = {1.0}; + deepmd::tabulate_fusion_se_t_tebd_grad_cpu( + dy_dem_x.data(), kTable.data(), kTableInfo.data(), em_x.data(), em.data(), + dy.data(), 1, 1, 1, kLastLayerSize); + return dy_dem_x[0]; +} + +double dot_vector(const std::vector& lhs, + const std::vector& rhs) { + double value = 0.0; + for (std::size_t ii = 0; ii < lhs.size(); ++ii) { + value += lhs[ii] * rhs[ii]; + } + return value; +} + +double se_a_grad_projection(const double xx, const std::vector& dy) { + std::vector dy_dem_x(1); + std::vector dy_dem(4); + std::vector dy_dtwo(1); + const std::vector em_x = {xx}; + const std::vector em = {1.0, -0.25, 0.5, -0.75}; + const std::vector dz_dy_dem_x = {0.6}; + const std::vector dz_dy_dem = {0.2, -0.3, 0.4, -0.5}; + deepmd::tabulate_fusion_se_a_grad_cpu( + dy_dem_x.data(), dy_dem.data(), dy_dtwo.data(), kTable.data(), + kTableInfo.data(), em_x.data(), em.data(), nullptr, dy.data(), 1, 1, + kLastLayerSize); + return dot_vector(dy_dem_x, dz_dy_dem_x) + dot_vector(dy_dem, dz_dy_dem); +} + +std::vector se_a_grad_grad_dy(const double xx) { + std::vector dz_dy(4 * kLastLayerSize); + const std::vector em_x = {xx}; + const std::vector em = {1.0, -0.25, 0.5, -0.75}; + const std::vector dz_dy_dem_x = {0.6}; + const std::vector dz_dy_dem = {0.2, -0.3, 0.4, -0.5}; + const std::vector dz_dy_dtwo(1); + deepmd::tabulate_fusion_se_a_grad_grad_cpu( + dz_dy.data(), kTable.data(), kTableInfo.data(), em_x.data(), em.data(), + nullptr, dz_dy_dem_x.data(), dz_dy_dem.data(), dz_dy_dtwo.data(), 1, 1, + kLastLayerSize); + return dz_dy; +} + +double se_r_grad_projection(const double xx, const std::vector& dy) { + std::vector dy_dem(1); + const std::vector em = {xx}; + const std::vector dz_dy_dem = {0.6}; + deepmd::tabulate_fusion_se_r_grad_cpu(dy_dem.data(), kTable.data(), + kTableInfo.data(), em.data(), dy.data(), + 1, 1, kLastLayerSize); + return dot_vector(dy_dem, dz_dy_dem); +} + +std::vector se_r_grad_grad_dy(const double xx) { + std::vector dz_dy(1); + const std::vector em = {xx}; + const std::vector dz_dy_dem = {0.6}; + deepmd::tabulate_fusion_se_r_grad_grad_cpu( + dz_dy.data(), kTable.data(), kTableInfo.data(), em.data(), + dz_dy_dem.data(), 1, 1, kLastLayerSize); + return dz_dy; +} + +double se_t_grad_projection(const double xx, const std::vector& dy) { + std::vector dy_dem_x(1); + std::vector dy_dem(1); + const std::vector em_x = {xx}; + const std::vector em = {xx}; + const std::vector dz_dy_dem_x = {0.6}; + const std::vector dz_dy_dem = {-0.4}; + deepmd::tabulate_fusion_se_t_grad_cpu( + dy_dem_x.data(), dy_dem.data(), kTable.data(), kTableInfo.data(), + em_x.data(), em.data(), dy.data(), 1, 1, 1, kLastLayerSize); + return dot_vector(dy_dem_x, dz_dy_dem_x) + dot_vector(dy_dem, dz_dy_dem); +} + +std::vector se_t_grad_grad_dy(const double xx) { + std::vector dz_dy(1); + const std::vector em_x = {xx}; + const std::vector em = {xx}; + const std::vector dz_dy_dem_x = {0.6}; + const std::vector dz_dy_dem = {-0.4}; + deepmd::tabulate_fusion_se_t_grad_grad_cpu( + dz_dy.data(), kTable.data(), kTableInfo.data(), em_x.data(), em.data(), + dz_dy_dem_x.data(), dz_dy_dem.data(), 1, 1, 1, kLastLayerSize); + return dz_dy; +} + +double se_t_tebd_grad_projection(const double xx, + const std::vector& dy) { + std::vector dy_dem_x(1); + const std::vector em_x = {xx}; + const std::vector em = {xx}; + const std::vector dz_dy_dem_x = {0.6}; + deepmd::tabulate_fusion_se_t_tebd_grad_cpu( + dy_dem_x.data(), kTable.data(), kTableInfo.data(), em_x.data(), em.data(), + dy.data(), 1, 1, 1, kLastLayerSize); + return dot_vector(dy_dem_x, dz_dy_dem_x); +} + +std::vector se_t_tebd_grad_grad_dy(const double xx) { + std::vector dz_dy(1); + const std::vector em_x = {xx}; + const std::vector em = {xx}; + const std::vector dz_dy_dem_x = {0.6}; + deepmd::tabulate_fusion_se_t_tebd_grad_grad_cpu( + dz_dy.data(), kTable.data(), kTableInfo.data(), em_x.data(), em.data(), + dz_dy_dem_x.data(), 1, 1, 1, kLastLayerSize); + return dz_dy; +} + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +double se_a_value_gpu(const double xx) { + std::vector out(4 * kLastLayerSize); + const std::vector em_x = {xx}; + const std::vector em = {1.0, 0.0, 0.0, 0.0}; + double *out_dev = nullptr, *table_dev = nullptr, *em_x_dev = nullptr, + *em_dev = nullptr; + deepmd::malloc_device_memory_sync(out_dev, out); + deepmd::malloc_device_memory_sync(table_dev, kTable); + deepmd::malloc_device_memory_sync(em_x_dev, em_x); + deepmd::malloc_device_memory_sync(em_dev, em); + deepmd::tabulate_fusion_se_a_gpu(out_dev, table_dev, + kTableInfo.data(), em_x_dev, em_dev, + nullptr, 1, 1, kLastLayerSize); + deepmd::memcpy_device_to_host(out_dev, out); + deepmd::delete_device_memory(out_dev); + deepmd::delete_device_memory(table_dev); + deepmd::delete_device_memory(em_x_dev); + deepmd::delete_device_memory(em_dev); + return out[0]; +} + +double se_a_grad_gpu(const double xx) { + std::vector dy_dem_x(1); + std::vector dy_dem(4); + const std::vector em_x = {xx}; + const std::vector em = {1.0, 0.0, 0.0, 0.0}; + const std::vector dy = {1.0, 0.0, 0.0, 0.0}; + double *dy_dem_x_dev = nullptr, *dy_dem_dev = nullptr, *table_dev = nullptr, + *em_x_dev = nullptr, *em_dev = nullptr, *dy_dev = nullptr; + deepmd::malloc_device_memory_sync(dy_dem_x_dev, dy_dem_x); + deepmd::malloc_device_memory_sync(dy_dem_dev, dy_dem); + deepmd::malloc_device_memory_sync(table_dev, kTable); + deepmd::malloc_device_memory_sync(em_x_dev, em_x); + deepmd::malloc_device_memory_sync(em_dev, em); + deepmd::malloc_device_memory_sync(dy_dev, dy); + deepmd::tabulate_fusion_se_a_grad_gpu( + dy_dem_x_dev, dy_dem_dev, nullptr, table_dev, kTableInfo.data(), em_x_dev, + em_dev, nullptr, dy_dev, 1, 1, kLastLayerSize); + deepmd::memcpy_device_to_host(dy_dem_x_dev, dy_dem_x); + deepmd::delete_device_memory(dy_dem_x_dev); + deepmd::delete_device_memory(dy_dem_dev); + deepmd::delete_device_memory(table_dev); + deepmd::delete_device_memory(em_x_dev); + deepmd::delete_device_memory(em_dev); + deepmd::delete_device_memory(dy_dev); + return dy_dem_x[0]; +} + +double se_r_value_gpu(const double xx) { + std::vector out(1); + const std::vector em = {xx}; + double *out_dev = nullptr, *table_dev = nullptr, *em_dev = nullptr; + deepmd::malloc_device_memory_sync(out_dev, out); + deepmd::malloc_device_memory_sync(table_dev, kTable); + deepmd::malloc_device_memory_sync(em_dev, em); + deepmd::tabulate_fusion_se_r_gpu( + out_dev, table_dev, kTableInfo.data(), em_dev, 1, 1, kLastLayerSize); + deepmd::memcpy_device_to_host(out_dev, out); + deepmd::delete_device_memory(out_dev); + deepmd::delete_device_memory(table_dev); + deepmd::delete_device_memory(em_dev); + return out[0]; +} + +double se_r_grad_gpu(const double xx) { + std::vector dy_dem(1); + const std::vector em = {xx}; + const std::vector dy = {1.0}; + double *dy_dem_dev = nullptr, *table_dev = nullptr, *em_dev = nullptr, + *dy_dev = nullptr; + deepmd::malloc_device_memory_sync(dy_dem_dev, dy_dem); + deepmd::malloc_device_memory_sync(table_dev, kTable); + deepmd::malloc_device_memory_sync(em_dev, em); + deepmd::malloc_device_memory_sync(dy_dev, dy); + deepmd::tabulate_fusion_se_r_grad_gpu(dy_dem_dev, table_dev, + kTableInfo.data(), em_dev, + dy_dev, 1, 1, kLastLayerSize); + deepmd::memcpy_device_to_host(dy_dem_dev, dy_dem); + deepmd::delete_device_memory(dy_dem_dev); + deepmd::delete_device_memory(table_dev); + deepmd::delete_device_memory(em_dev); + deepmd::delete_device_memory(dy_dev); + return dy_dem[0]; +} + +double se_t_value_gpu(const double xx) { + std::vector out(1); + const std::vector em_x = {xx}; + const std::vector em = {xx}; + double *out_dev = nullptr, *table_dev = nullptr, *em_x_dev = nullptr, + *em_dev = nullptr; + deepmd::malloc_device_memory_sync(out_dev, out); + deepmd::malloc_device_memory_sync(table_dev, kTable); + deepmd::malloc_device_memory_sync(em_x_dev, em_x); + deepmd::malloc_device_memory_sync(em_dev, em); + deepmd::tabulate_fusion_se_t_gpu(out_dev, table_dev, + kTableInfo.data(), em_x_dev, em_dev, + 1, 1, 1, kLastLayerSize); + deepmd::memcpy_device_to_host(out_dev, out); + deepmd::delete_device_memory(out_dev); + deepmd::delete_device_memory(table_dev); + deepmd::delete_device_memory(em_x_dev); + deepmd::delete_device_memory(em_dev); + return out[0]; +} + +void se_t_grad_gpu(const double xx, double& dy_dem_x, double& dy_dem) { + std::vector dy_dem_x_vec(1); + std::vector dy_dem_vec(1); + const std::vector em_x = {xx}; + const std::vector em = {xx}; + const std::vector dy = {1.0}; + double *dy_dem_x_dev = nullptr, *dy_dem_dev = nullptr, *table_dev = nullptr, + *em_x_dev = nullptr, *em_dev = nullptr, *dy_dev = nullptr; + deepmd::malloc_device_memory_sync(dy_dem_x_dev, dy_dem_x_vec); + deepmd::malloc_device_memory_sync(dy_dem_dev, dy_dem_vec); + deepmd::malloc_device_memory_sync(table_dev, kTable); + deepmd::malloc_device_memory_sync(em_x_dev, em_x); + deepmd::malloc_device_memory_sync(em_dev, em); + deepmd::malloc_device_memory_sync(dy_dev, dy); + deepmd::tabulate_fusion_se_t_grad_gpu( + dy_dem_x_dev, dy_dem_dev, table_dev, kTableInfo.data(), em_x_dev, em_dev, + dy_dev, 1, 1, 1, kLastLayerSize); + deepmd::memcpy_device_to_host(dy_dem_x_dev, dy_dem_x_vec); + deepmd::memcpy_device_to_host(dy_dem_dev, dy_dem_vec); + deepmd::delete_device_memory(dy_dem_x_dev); + deepmd::delete_device_memory(dy_dem_dev); + deepmd::delete_device_memory(table_dev); + deepmd::delete_device_memory(em_x_dev); + deepmd::delete_device_memory(em_dev); + deepmd::delete_device_memory(dy_dev); + dy_dem_x = dy_dem_x_vec[0]; + dy_dem = dy_dem_vec[0]; +} + +double se_t_tebd_value_gpu(const double xx) { + std::vector out(1); + const std::vector em_x = {xx}; + const std::vector em = {xx}; + double *out_dev = nullptr, *table_dev = nullptr, *em_x_dev = nullptr, + *em_dev = nullptr; + deepmd::malloc_device_memory_sync(out_dev, out); + deepmd::malloc_device_memory_sync(table_dev, kTable); + deepmd::malloc_device_memory_sync(em_x_dev, em_x); + deepmd::malloc_device_memory_sync(em_dev, em); + deepmd::tabulate_fusion_se_t_tebd_gpu( + out_dev, table_dev, kTableInfo.data(), em_x_dev, em_dev, 1, 1, 1, + kLastLayerSize); + deepmd::memcpy_device_to_host(out_dev, out); + deepmd::delete_device_memory(out_dev); + deepmd::delete_device_memory(table_dev); + deepmd::delete_device_memory(em_x_dev); + deepmd::delete_device_memory(em_dev); + return out[0]; +} + +double se_t_tebd_grad_gpu(const double xx) { + std::vector dy_dem_x(1); + const std::vector em_x = {xx}; + const std::vector em = {xx}; + const std::vector dy = {1.0}; + double *dy_dem_x_dev = nullptr, *table_dev = nullptr, *em_x_dev = nullptr, + *em_dev = nullptr, *dy_dev = nullptr; + deepmd::malloc_device_memory_sync(dy_dem_x_dev, dy_dem_x); + deepmd::malloc_device_memory_sync(table_dev, kTable); + deepmd::malloc_device_memory_sync(em_x_dev, em_x); + deepmd::malloc_device_memory_sync(em_dev, em); + deepmd::malloc_device_memory_sync(dy_dev, dy); + deepmd::tabulate_fusion_se_t_tebd_grad_gpu( + dy_dem_x_dev, table_dev, kTableInfo.data(), em_x_dev, em_dev, dy_dev, 1, + 1, 1, kLastLayerSize); + deepmd::memcpy_device_to_host(dy_dem_x_dev, dy_dem_x); + deepmd::delete_device_memory(dy_dem_x_dev); + deepmd::delete_device_memory(table_dev); + deepmd::delete_device_memory(em_x_dev); + deepmd::delete_device_memory(em_dev); + deepmd::delete_device_memory(dy_dev); + return dy_dem_x[0]; +} + +double se_a_grad_projection_gpu(const double xx, + const std::vector& dy) { + std::vector dy_dem_x(1); + std::vector dy_dem(4); + std::vector table = kTable; + std::vector em_x = {xx}; + std::vector em = {1.0, -0.25, 0.5, -0.75}; + std::vector dy_host = dy; + const std::vector dz_dy_dem_x = {0.6}; + const std::vector dz_dy_dem = {0.2, -0.3, 0.4, -0.5}; + double *dy_dem_x_dev = nullptr, *dy_dem_dev = nullptr, *table_dev = nullptr, + *em_x_dev = nullptr, *em_dev = nullptr, *dy_dev = nullptr; + deepmd::malloc_device_memory_sync(dy_dem_x_dev, dy_dem_x); + deepmd::malloc_device_memory_sync(dy_dem_dev, dy_dem); + deepmd::malloc_device_memory_sync(table_dev, table); + deepmd::malloc_device_memory_sync(em_x_dev, em_x); + deepmd::malloc_device_memory_sync(em_dev, em); + deepmd::malloc_device_memory_sync(dy_dev, dy_host); + deepmd::tabulate_fusion_se_a_grad_gpu( + dy_dem_x_dev, dy_dem_dev, nullptr, table_dev, kTableInfo.data(), em_x_dev, + em_dev, nullptr, dy_dev, 1, 1, kLastLayerSize); + deepmd::memcpy_device_to_host(dy_dem_x_dev, dy_dem_x); + deepmd::memcpy_device_to_host(dy_dem_dev, dy_dem); + deepmd::delete_device_memory(dy_dem_x_dev); + deepmd::delete_device_memory(dy_dem_dev); + deepmd::delete_device_memory(table_dev); + deepmd::delete_device_memory(em_x_dev); + deepmd::delete_device_memory(em_dev); + deepmd::delete_device_memory(dy_dev); + return dot_vector(dy_dem_x, dz_dy_dem_x) + dot_vector(dy_dem, dz_dy_dem); +} + +std::vector se_a_grad_grad_dy_gpu(const double xx) { + std::vector dz_dy(4 * kLastLayerSize); + std::vector table = kTable; + std::vector em_x = {xx}; + std::vector em = {1.0, -0.25, 0.5, -0.75}; + std::vector dz_dy_dem_x = {0.6}; + std::vector dz_dy_dem = {0.2, -0.3, 0.4, -0.5}; + double *dz_dy_dev = nullptr, *table_dev = nullptr, *em_x_dev = nullptr, + *em_dev = nullptr, *dz_dy_dem_x_dev = nullptr, + *dz_dy_dem_dev = nullptr; + deepmd::malloc_device_memory_sync(dz_dy_dev, dz_dy); + deepmd::malloc_device_memory_sync(table_dev, table); + deepmd::malloc_device_memory_sync(em_x_dev, em_x); + deepmd::malloc_device_memory_sync(em_dev, em); + deepmd::malloc_device_memory_sync(dz_dy_dem_x_dev, dz_dy_dem_x); + deepmd::malloc_device_memory_sync(dz_dy_dem_dev, dz_dy_dem); + deepmd::tabulate_fusion_se_a_grad_grad_gpu( + dz_dy_dev, table_dev, kTableInfo.data(), em_x_dev, em_dev, nullptr, + dz_dy_dem_x_dev, dz_dy_dem_dev, nullptr, 1, 1, kLastLayerSize); + deepmd::memcpy_device_to_host(dz_dy_dev, dz_dy); + deepmd::delete_device_memory(dz_dy_dev); + deepmd::delete_device_memory(table_dev); + deepmd::delete_device_memory(em_x_dev); + deepmd::delete_device_memory(em_dev); + deepmd::delete_device_memory(dz_dy_dem_x_dev); + deepmd::delete_device_memory(dz_dy_dem_dev); + return dz_dy; +} + +double se_r_grad_projection_gpu(const double xx, + const std::vector& dy) { + std::vector dy_dem(1); + std::vector table = kTable; + std::vector em = {xx}; + std::vector dy_host = dy; + const std::vector dz_dy_dem = {0.6}; + double *dy_dem_dev = nullptr, *table_dev = nullptr, *em_dev = nullptr, + *dy_dev = nullptr; + deepmd::malloc_device_memory_sync(dy_dem_dev, dy_dem); + deepmd::malloc_device_memory_sync(table_dev, table); + deepmd::malloc_device_memory_sync(em_dev, em); + deepmd::malloc_device_memory_sync(dy_dev, dy_host); + deepmd::tabulate_fusion_se_r_grad_gpu(dy_dem_dev, table_dev, + kTableInfo.data(), em_dev, + dy_dev, 1, 1, kLastLayerSize); + deepmd::memcpy_device_to_host(dy_dem_dev, dy_dem); + deepmd::delete_device_memory(dy_dem_dev); + deepmd::delete_device_memory(table_dev); + deepmd::delete_device_memory(em_dev); + deepmd::delete_device_memory(dy_dev); + return dot_vector(dy_dem, dz_dy_dem); +} + +std::vector se_r_grad_grad_dy_gpu(const double xx) { + std::vector dz_dy(1); + std::vector table = kTable; + std::vector em = {xx}; + std::vector dz_dy_dem = {0.6}; + double *dz_dy_dev = nullptr, *table_dev = nullptr, *em_dev = nullptr, + *dz_dy_dem_dev = nullptr; + deepmd::malloc_device_memory_sync(dz_dy_dev, dz_dy); + deepmd::malloc_device_memory_sync(table_dev, table); + deepmd::malloc_device_memory_sync(em_dev, em); + deepmd::malloc_device_memory_sync(dz_dy_dem_dev, dz_dy_dem); + deepmd::tabulate_fusion_se_r_grad_grad_gpu( + dz_dy_dev, table_dev, kTableInfo.data(), em_dev, dz_dy_dem_dev, 1, 1, + kLastLayerSize); + deepmd::memcpy_device_to_host(dz_dy_dev, dz_dy); + deepmd::delete_device_memory(dz_dy_dev); + deepmd::delete_device_memory(table_dev); + deepmd::delete_device_memory(em_dev); + deepmd::delete_device_memory(dz_dy_dem_dev); + return dz_dy; +} + +double se_t_grad_projection_gpu(const double xx, + const std::vector& dy) { + std::vector dy_dem_x(1); + std::vector dy_dem(1); + std::vector table = kTable; + std::vector em_x = {xx}; + std::vector em = {xx}; + std::vector dy_host = dy; + const std::vector dz_dy_dem_x = {0.6}; + const std::vector dz_dy_dem = {-0.4}; + double *dy_dem_x_dev = nullptr, *dy_dem_dev = nullptr, *table_dev = nullptr, + *em_x_dev = nullptr, *em_dev = nullptr, *dy_dev = nullptr; + deepmd::malloc_device_memory_sync(dy_dem_x_dev, dy_dem_x); + deepmd::malloc_device_memory_sync(dy_dem_dev, dy_dem); + deepmd::malloc_device_memory_sync(table_dev, table); + deepmd::malloc_device_memory_sync(em_x_dev, em_x); + deepmd::malloc_device_memory_sync(em_dev, em); + deepmd::malloc_device_memory_sync(dy_dev, dy_host); + deepmd::tabulate_fusion_se_t_grad_gpu( + dy_dem_x_dev, dy_dem_dev, table_dev, kTableInfo.data(), em_x_dev, em_dev, + dy_dev, 1, 1, 1, kLastLayerSize); + deepmd::memcpy_device_to_host(dy_dem_x_dev, dy_dem_x); + deepmd::memcpy_device_to_host(dy_dem_dev, dy_dem); + deepmd::delete_device_memory(dy_dem_x_dev); + deepmd::delete_device_memory(dy_dem_dev); + deepmd::delete_device_memory(table_dev); + deepmd::delete_device_memory(em_x_dev); + deepmd::delete_device_memory(em_dev); + deepmd::delete_device_memory(dy_dev); + return dot_vector(dy_dem_x, dz_dy_dem_x) + dot_vector(dy_dem, dz_dy_dem); +} + +std::vector se_t_grad_grad_dy_gpu(const double xx) { + std::vector dz_dy(1); + std::vector table = kTable; + std::vector em_x = {xx}; + std::vector em = {xx}; + std::vector dz_dy_dem_x = {0.6}; + std::vector dz_dy_dem = {-0.4}; + double *dz_dy_dev = nullptr, *table_dev = nullptr, *em_x_dev = nullptr, + *em_dev = nullptr, *dz_dy_dem_x_dev = nullptr, + *dz_dy_dem_dev = nullptr; + deepmd::malloc_device_memory_sync(dz_dy_dev, dz_dy); + deepmd::malloc_device_memory_sync(table_dev, table); + deepmd::malloc_device_memory_sync(em_x_dev, em_x); + deepmd::malloc_device_memory_sync(em_dev, em); + deepmd::malloc_device_memory_sync(dz_dy_dem_x_dev, dz_dy_dem_x); + deepmd::malloc_device_memory_sync(dz_dy_dem_dev, dz_dy_dem); + deepmd::tabulate_fusion_se_t_grad_grad_gpu( + dz_dy_dev, table_dev, kTableInfo.data(), em_x_dev, em_dev, + dz_dy_dem_x_dev, dz_dy_dem_dev, 1, 1, 1, kLastLayerSize); + deepmd::memcpy_device_to_host(dz_dy_dev, dz_dy); + deepmd::delete_device_memory(dz_dy_dev); + deepmd::delete_device_memory(table_dev); + deepmd::delete_device_memory(em_x_dev); + deepmd::delete_device_memory(em_dev); + deepmd::delete_device_memory(dz_dy_dem_x_dev); + deepmd::delete_device_memory(dz_dy_dem_dev); + return dz_dy; +} + +double se_t_tebd_grad_projection_gpu(const double xx, + const std::vector& dy) { + std::vector dy_dem_x(1); + std::vector table = kTable; + std::vector em_x = {xx}; + std::vector em = {xx}; + std::vector dy_host = dy; + const std::vector dz_dy_dem_x = {0.6}; + double *dy_dem_x_dev = nullptr, *table_dev = nullptr, *em_x_dev = nullptr, + *em_dev = nullptr, *dy_dev = nullptr; + deepmd::malloc_device_memory_sync(dy_dem_x_dev, dy_dem_x); + deepmd::malloc_device_memory_sync(table_dev, table); + deepmd::malloc_device_memory_sync(em_x_dev, em_x); + deepmd::malloc_device_memory_sync(em_dev, em); + deepmd::malloc_device_memory_sync(dy_dev, dy_host); + deepmd::tabulate_fusion_se_t_tebd_grad_gpu( + dy_dem_x_dev, table_dev, kTableInfo.data(), em_x_dev, em_dev, dy_dev, 1, + 1, 1, kLastLayerSize); + deepmd::memcpy_device_to_host(dy_dem_x_dev, dy_dem_x); + deepmd::delete_device_memory(dy_dem_x_dev); + deepmd::delete_device_memory(table_dev); + deepmd::delete_device_memory(em_x_dev); + deepmd::delete_device_memory(em_dev); + deepmd::delete_device_memory(dy_dev); + return dot_vector(dy_dem_x, dz_dy_dem_x); +} + +std::vector se_t_tebd_grad_grad_dy_gpu(const double xx) { + std::vector dz_dy(1); + std::vector table = kTable; + std::vector em_x = {xx}; + std::vector em = {xx}; + std::vector dz_dy_dem_x = {0.6}; + double *dz_dy_dev = nullptr, *table_dev = nullptr, *em_x_dev = nullptr, + *em_dev = nullptr, *dz_dy_dem_x_dev = nullptr; + deepmd::malloc_device_memory_sync(dz_dy_dev, dz_dy); + deepmd::malloc_device_memory_sync(table_dev, table); + deepmd::malloc_device_memory_sync(em_x_dev, em_x); + deepmd::malloc_device_memory_sync(em_dev, em); + deepmd::malloc_device_memory_sync(dz_dy_dem_x_dev, dz_dy_dem_x); + deepmd::tabulate_fusion_se_t_tebd_grad_grad_gpu( + dz_dy_dev, table_dev, kTableInfo.data(), em_x_dev, em_dev, + dz_dy_dem_x_dev, 1, 1, 1, kLastLayerSize); + deepmd::memcpy_device_to_host(dz_dy_dev, dz_dy); + deepmd::delete_device_memory(dz_dy_dev); + deepmd::delete_device_memory(table_dev); + deepmd::delete_device_memory(em_x_dev); + deepmd::delete_device_memory(em_dev); + deepmd::delete_device_memory(dz_dy_dem_x_dev); + return dz_dy; +} +#endif + +double central_diff(double (*fn)(double), const double xx) { + return (fn(xx + kFiniteDiffStep) - fn(xx - kFiniteDiffStep)) / + (2.0 * kFiniteDiffStep); +} + +double grad_central_diff(double (*fn)(double), const double xx) { + return (fn(xx + kFiniteDiffStep) - fn(xx - kFiniteDiffStep)) / + (2.0 * kFiniteDiffStep); +} + +void expect_linear_tail(double (*value_fn)(double), + double (*grad_fn)(double), + const double xx, + const LocatedXx& located) { + EXPECT_NEAR(value_fn(xx), expected_table_value(located), kTolerance); + EXPECT_NEAR(grad_fn(xx), expected_table_grad(located), kTolerance); + EXPECT_NEAR(central_diff(value_fn, xx), grad_fn(xx), 1e-8); + EXPECT_NEAR(grad_central_diff(grad_fn, xx), 0.0, 1e-10); +} + +void expect_boundary(double (*value_fn)(double), + double (*grad_fn)(double), + const double xx, + const LocatedXx& located) { + EXPECT_NEAR(value_fn(xx), expected_table_value(located), kTolerance); + EXPECT_NEAR(grad_fn(xx), expected_table_grad(located), kTolerance); +} + +void expect_se_t_lookup_tail(double (*value_fn)(double), + void (*grad_fn)(double, double&, double&), + const double xx) { + const LocatedXx located = locate_se_t(xx); + const double expected_value = expected_table_value(located); + const double expected_grad = expected_table_grad(located); + double dy_dem_x = 0.0; + double dy_dem = 0.0; + grad_fn(xx, dy_dem_x, dy_dem); + EXPECT_NEAR(value_fn(xx), xx * expected_value, kTolerance); + EXPECT_NEAR(dy_dem_x, xx * expected_grad, kTolerance); + EXPECT_NEAR(dy_dem, expected_value, kTolerance); + EXPECT_NEAR(central_diff(value_fn, xx), dy_dem_x + dy_dem, 1e-5); +} + +using GradGradFn = std::vector (*)(double); +using ProjectionFn = double (*)(double, const std::vector&); + +void expect_grad_grad_matches_dy_finite_diff(GradGradFn grad_grad_fn, + ProjectionFn projection_fn, + const double xx, + std::vector dy) { + const std::vector actual = grad_grad_fn(xx); + ASSERT_EQ(actual.size(), dy.size()); + for (std::size_t ii = 0; ii < dy.size(); ++ii) { + std::vector dy_plus = dy; + std::vector dy_minus = dy; + dy_plus[ii] += kFiniteDiffStep; + dy_minus[ii] -= kFiniteDiffStep; + const double reference = + (projection_fn(xx, dy_plus) - projection_fn(xx, dy_minus)) / + (2.0 * kFiniteDiffStep); + EXPECT_NEAR(actual[ii], reference, 1e-8); + } +} + +} // namespace + +TEST(TabulateExtrapolate, SeAUsesC1LinearTails) { + expect_linear_tail(se_a_value, se_a_grad, kLower - 0.25, + locate_se_a_or_r(kLower - 0.25)); + expect_boundary(se_a_value, se_a_grad, kMax, locate_se_a_or_r(kMax)); + expect_linear_tail(se_a_value, se_a_grad, kMax + 0.25, + locate_se_a_or_r(kMax + 0.25)); +} + +TEST(TabulateExtrapolate, SeAWithTwoEmbedUsesC1LinearTail) { + const double xx = kMax + 0.25; + const LocatedXx located = locate_se_a_or_r(xx); + double dy_dem_x = 0.0; + double dy_dtwo = 0.0; + se_a_grad_with_two_embed(xx, dy_dem_x, dy_dtwo); + EXPECT_NEAR(se_a_value_with_two_embed(xx), + (1.0 + kTwoEmbed) * expected_table_value(located), kTolerance); + EXPECT_NEAR(dy_dem_x, (1.0 + kTwoEmbed) * expected_table_grad(located), + kTolerance); + EXPECT_NEAR(dy_dtwo, expected_table_value(located), kTolerance); + EXPECT_NEAR(central_diff(se_a_value_with_two_embed, xx), dy_dem_x, 1e-8); + EXPECT_NEAR(grad_central_diff(se_a_dem_x_grad_with_two_embed, xx), 0.0, + 1e-10); +} + +TEST(TabulateExtrapolate, SeRUsesC1LinearTails) { + expect_linear_tail(se_r_value, se_r_grad, kLower - 0.25, + locate_se_a_or_r(kLower - 0.25)); + expect_boundary(se_r_value, se_r_grad, kMax, locate_se_a_or_r(kMax)); + expect_linear_tail(se_r_value, se_r_grad, kMax + 0.25, + locate_se_a_or_r(kMax + 0.25)); +} + +TEST(TabulateExtrapolate, SeROffGridMaxUsesBoundarySegment) { + const double max = kOffGridTableInfo[2]; + for (const double xx : {max, max + 0.25}) { + const LocatedXx located = locate_se_a_or_r(xx, kOffGridTableInfo); + EXPECT_NEAR(se_r_value(xx, kOffGridTableInfo), + expected_table_value(located), kTolerance); + EXPECT_NEAR(se_r_grad(xx, kOffGridTableInfo), expected_table_grad(located), + kTolerance); + } +} + +TEST(TabulateExtrapolate, GradGradKernelsMatchDyFiniteDifferenceInTails) { + for (const double xx : {kLower - 0.25, kMax + 0.25}) { + expect_grad_grad_matches_dy_finite_diff( + se_a_grad_grad_dy, se_a_grad_projection, xx, {0.1, -0.2, 0.3, -0.4}); + expect_grad_grad_matches_dy_finite_diff(se_r_grad_grad_dy, + se_r_grad_projection, xx, {0.1}); + } + for (const double xx : {kMin - 0.25, kMax + 0.25}) { + expect_grad_grad_matches_dy_finite_diff(se_t_grad_grad_dy, + se_t_grad_projection, xx, {0.1}); + expect_grad_grad_matches_dy_finite_diff( + se_t_tebd_grad_grad_dy, se_t_tebd_grad_projection, xx, {0.1}); + } +} + +TEST(TabulateExtrapolate, SeTUsesLinearLookupTails) { + for (const double xx : {kMin - 0.25, kMin, kMax, kMax + 0.25}) { + expect_se_t_lookup_tail(se_t_value, se_t_grad, xx); + } +} + +TEST(TabulateExtrapolate, SeTTebdUsesC1LinearTails) { + expect_linear_tail(se_t_tebd_value, se_t_tebd_grad, kMin - 0.25, + locate_se_t(kMin - 0.25)); + expect_boundary(se_t_tebd_value, se_t_tebd_grad, kMax, locate_se_t(kMax)); + expect_linear_tail(se_t_tebd_value, se_t_tebd_grad, kMax + 0.25, + locate_se_t(kMax + 0.25)); +} + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +TEST(TabulateExtrapolate, SeAGpuUsesC1LinearTails) { + expect_linear_tail(se_a_value_gpu, se_a_grad_gpu, kLower - 0.25, + locate_se_a_or_r(kLower - 0.25)); + expect_boundary(se_a_value_gpu, se_a_grad_gpu, kMax, locate_se_a_or_r(kMax)); + expect_linear_tail(se_a_value_gpu, se_a_grad_gpu, kMax + 0.25, + locate_se_a_or_r(kMax + 0.25)); +} + +TEST(TabulateExtrapolate, SeRGpuUsesC1LinearTails) { + expect_linear_tail(se_r_value_gpu, se_r_grad_gpu, kLower - 0.25, + locate_se_a_or_r(kLower - 0.25)); + expect_boundary(se_r_value_gpu, se_r_grad_gpu, kMax, locate_se_a_or_r(kMax)); + expect_linear_tail(se_r_value_gpu, se_r_grad_gpu, kMax + 0.25, + locate_se_a_or_r(kMax + 0.25)); +} + +TEST(TabulateExtrapolate, SeTGpuUsesLinearLookupTails) { + for (const double xx : {kMin - 0.25, kMin, kMax, kMax + 0.25}) { + expect_se_t_lookup_tail(se_t_value_gpu, se_t_grad_gpu, xx); + } +} + +TEST(TabulateExtrapolate, SeTTebdGpuUsesC1LinearTails) { + expect_linear_tail(se_t_tebd_value_gpu, se_t_tebd_grad_gpu, kMin - 0.25, + locate_se_t(kMin - 0.25)); + expect_boundary(se_t_tebd_value_gpu, se_t_tebd_grad_gpu, kMax, + locate_se_t(kMax)); + expect_linear_tail(se_t_tebd_value_gpu, se_t_tebd_grad_gpu, kMax + 0.25, + locate_se_t(kMax + 0.25)); +} + +TEST(TabulateExtrapolate, GpuGradGradKernelsMatchDyFiniteDifferenceInTails) { + for (const double xx : {kLower - 0.25, kMax + 0.25}) { + expect_grad_grad_matches_dy_finite_diff(se_a_grad_grad_dy_gpu, + se_a_grad_projection_gpu, xx, + {0.1, -0.2, 0.3, -0.4}); + expect_grad_grad_matches_dy_finite_diff( + se_r_grad_grad_dy_gpu, se_r_grad_projection_gpu, xx, {0.1}); + } + for (const double xx : {kMin - 0.25, kMax + 0.25}) { + expect_grad_grad_matches_dy_finite_diff( + se_t_grad_grad_dy_gpu, se_t_grad_projection_gpu, xx, {0.1}); + expect_grad_grad_matches_dy_finite_diff( + se_t_tebd_grad_grad_dy_gpu, se_t_tebd_grad_projection_gpu, xx, {0.1}); + } +} +#endif