From 8012ee9be7c55c51d2000f29c77b047cdafc7970 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 4 Aug 2021 00:05:52 -0400 Subject: [PATCH 1/6] Passing error to TF instead of exit This commit does three little things: (1) create an exception called `deepmd::deepmd_exception` (based on `std::runtime_error`); (2) throw this exception instead of `exit` or `std::runtime_error`; (3) catch this exception in the op, and pass it to TF using `OP_REQUIRES_OK`. In addition, the OOM error will raise ResourceExhausted, the same as TF ops. The benefit of doing so is that the TF side and Python side can keep processing other things, catch the error, and print the traceback. This commit also fixes #802, where Python didn't save the buffer to the file before exiting. --- source/lib/include/SimulationRegion_Impl.h | 3 +- source/lib/include/errors.h | 20 +++++++++ source/lib/include/gpu_cuda.h | 9 ++-- source/lib/include/gpu_rocm.h | 5 ++- source/lib/src/fmt_nlist.cc | 3 +- source/lib/src/pair_tab.cc | 3 +- source/lib/src/prod_force.cc | 3 +- source/lib/src/prod_force_grad.cc | 3 +- source/lib/src/prod_virial.cc | 3 +- source/lib/src/prod_virial_grad.cc | 3 +- source/lib/src/region.cc | 3 +- source/op/descrpt.cc | 17 +++++++- source/op/descrpt_se_a_ef.cc | 17 +++++++- source/op/descrpt_se_a_ef_para.cc | 17 +++++++- source/op/descrpt_se_a_ef_vert.cc | 17 +++++++- source/op/ewald_recp.cc | 13 ++++++ source/op/gelu_multi_device.cc | 37 ++++++++++++++++ source/op/legacy/descrpt_se_a.cc | 4 +- source/op/legacy/descrpt_se_r.cc | 4 +- source/op/map_aparam.cc | 13 ++++++ source/op/neighbor_stat.cc | 17 +++++++- source/op/prod_env_mat_multi_device.cc | 31 +++++++++++++- source/op/prod_force.cc | 13 ++++++ source/op/prod_force_grad.cc | 13 ++++++ source/op/prod_force_grad_multi_device.cc | 25 +++++++++++ source/op/prod_force_multi_device.cc | 25 +++++++++++ source/op/prod_force_se_a_grad.cc | 13 ++++++ source/op/prod_force_se_r_grad.cc | 13 ++++++ source/op/prod_virial.cc | 13 ++++++ 
source/op/prod_virial_grad.cc | 13 ++++++ source/op/prod_virial_grad_multi_device.cc | 25 +++++++++++ source/op/prod_virial_multi_device.cc | 25 +++++++++++ source/op/prod_virial_se_a_grad.cc | 13 ++++++ source/op/prod_virial_se_r_grad.cc | 13 ++++++ source/op/soft_min.cc | 13 ++++++ source/op/soft_min_force.cc | 13 ++++++ source/op/soft_min_force_grad.cc | 13 ++++++ source/op/soft_min_virial.cc | 13 ++++++ source/op/soft_min_virial_grad.cc | 13 ++++++ source/op/tabulate_multi_device.cc | 25 +++++++++++ source/op/unaggregated_grad.cc | 49 ++++++++++++++++++++++ 41 files changed, 558 insertions(+), 30 deletions(-) create mode 100644 source/lib/include/errors.h diff --git a/source/lib/include/SimulationRegion_Impl.h b/source/lib/include/SimulationRegion_Impl.h index 5b7b8248fd..528402b7d6 100644 --- a/source/lib/include/SimulationRegion_Impl.h +++ b/source/lib/include/SimulationRegion_Impl.h @@ -6,6 +6,7 @@ #include #include #include +#include "errors.h" // using namespace std; @@ -502,7 +503,7 @@ computeVolume() boxt[0*3+2] * (boxt[1*3+0]*boxt[2*3+1] - boxt[2*3+0]*boxt[1*3+1]); volumei = static_cast(1.)/volume; if (volume < 0) { - throw std::runtime_error("Negative volume detected. Please make sure the simulation cell obeys the right-hand rule."); + throw deepmd::deepmd_exception("Negative volume detected. 
Please make sure the simulation cell obeys the right-hand rule."); } } diff --git a/source/lib/include/errors.h b/source/lib/include/errors.h new file mode 100644 index 0000000000..fe0a21fc50 --- /dev/null +++ b/source/lib/include/errors.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +namespace deepmd{ + struct + deepmd_exception: public std::runtime_error { + public: + deepmd_exception(): runtime_error("DeePMD-kit Error!") {}; + deepmd_exception(const std::string& msg): runtime_error(std::string("DeePMD-kit Error: ") + msg) {}; + }; + + struct + deepmd_exception_oom: public std::runtime_error{ + public: + deepmd_exception_oom(): runtime_error("DeePMD-kit OOM!") {}; + deepmd_exception_oom(const std::string& msg): runtime_error(std::string("DeePMD-kit OOM: ") + msg) {}; + }; +}; \ No newline at end of file diff --git a/source/lib/include/gpu_cuda.h b/source/lib/include/gpu_cuda.h index cd82ee4657..8a2b617c95 100644 --- a/source/lib/include/gpu_cuda.h +++ b/source/lib/include/gpu_cuda.h @@ -3,6 +3,7 @@ #include #include #include +#include "errors.h" #define GPU_MAX_NBOR_SIZE 4096 #define DPErrcheck(res) {DPAssert((res), __FILE__, __LINE__);} @@ -12,7 +13,6 @@ inline void DPAssert(cudaError_t code, const char *file, int line, bool abort=tr fprintf(stderr,"cuda assert: %s %s %d\n", cudaGetErrorString(code), file, line); if (code == 2) { // out of memory - // TODO: I have no idea how to thorw errors back to Python interface fprintf(stderr, "Your memory is not enough, thus an error has been raised " \ "above. You need to take the following actions:\n" \ "1. Check if the network size of the model is too large.\n" \ @@ -22,8 +22,9 @@ inline void DPAssert(cudaError_t code, const char *file, int line, bool abort=tr "4. Check if another program is using the same GPU by execuating `nvidia-smi`. 
" \ "The usage of GPUs is controlled by `CUDA_VISIBLE_DEVICES` " \ "environment variable.\n"); + if (abort) throw deepmd::deepmd_exception_oom("CUDA Assert"); } - if (abort) exit(code); + if (abort) throw deepmd::deepmd_exception("CUDA Assert"); } } @@ -34,7 +35,6 @@ inline void nborAssert(cudaError_t code, const char *file, int line, bool abort= fprintf(stderr,"cuda assert: %s %s %d\n", "DeePMD-kit:\tillegal nbor list sorting", file, line); if (code == 2) { // out of memory - // TODO: I have no idea how to thorw errors back to Python interface fprintf(stderr, "Your memory is not enough, thus an error has been raised " \ "above. You need to take the following actions:\n" \ "1. Check if the network size of the model is too large.\n" \ @@ -44,8 +44,9 @@ inline void nborAssert(cudaError_t code, const char *file, int line, bool abort= "4. Check if another program is using the same GPU by execuating `nvidia-smi`. " \ "The usage of GPUs is controlled by `CUDA_VISIBLE_DEVICES` " \ "environment variable.\n"); + if (abort) throw deepmd::deepmd_exception_oom("CUDA Assert"); } - if (abort) exit(code); + if (abort) throw deepmd::deepmd_exception("CUDA Assert"); } } diff --git a/source/lib/include/gpu_rocm.h b/source/lib/include/gpu_rocm.h index 955ffe5bf7..b6439c3bb8 100644 --- a/source/lib/include/gpu_rocm.h +++ b/source/lib/include/gpu_rocm.h @@ -5,6 +5,7 @@ #include //#include //#include +#include "errors.h" #define GPU_MAX_NBOR_SIZE 4096 @@ -12,7 +13,7 @@ inline void DPAssert(hipError_t code, const char *file, int line, bool abort=true) { if (code != hipSuccess) { fprintf(stderr,"hip assert: %s %s %d\n", hipGetErrorString(code), file, line); - if (abort) exit(code); + if (abort) throw deepmd::deepmd_exception("CUDA Assert"); } } @@ -20,7 +21,7 @@ inline void DPAssert(hipError_t code, const char *file, int line, bool abort=tru inline void nborAssert(hipError_t code, const char *file, int line, bool abort=true) { if (code != hipSuccess) { fprintf(stderr,"hip assert: %s %s 
%d\n", "DeePMD-kit:\tillegal nbor list sorting", file, line); - if (abort) exit(code); + if (abort) throw deepmd::deepmd_exception("CUDA Assert"); } } diff --git a/source/lib/src/fmt_nlist.cc b/source/lib/src/fmt_nlist.cc index add83dadcf..35155d77d1 100644 --- a/source/lib/src/fmt_nlist.cc +++ b/source/lib/src/fmt_nlist.cc @@ -4,6 +4,7 @@ #include "fmt_nlist.h" #include "SimulationRegion.h" #include +#include "errors.h" using namespace deepmd; @@ -185,7 +186,7 @@ format_nlist_cpu ( << fmt_ilist.size() << " which does not match " << nnei << std::endl; - exit(1); + throw deepmd::deepmd_exception(); } std::copy(fmt_ilist.begin(), fmt_ilist.end(), cur_nlist); } diff --git a/source/lib/src/pair_tab.cc b/source/lib/src/pair_tab.cc index 5137e17ac9..2c48ce957a 100644 --- a/source/lib/src/pair_tab.cc +++ b/source/lib/src/pair_tab.cc @@ -3,6 +3,7 @@ #include #include #include "pair_tab.h" +#include "errors.h" inline void _pair_tabulated_inter ( @@ -25,7 +26,7 @@ void _pair_tabulated_inter ( // std::cout << rr << " " << rmin << " " << hh << " " << uu << std::endl; if (uu < 0) { std::cerr << "coord go beyond table lower boundary" << std::endl; - exit(1); + throw deepmd::deepmd_exception(); } int idx = uu; if (idx >= nspline) { diff --git a/source/lib/src/prod_force.cc b/source/lib/src/prod_force.cc index ffe177e16c..e9784d3409 100644 --- a/source/lib/src/prod_force.cc +++ b/source/lib/src/prod_force.cc @@ -1,6 +1,7 @@ #include #include #include "prod_force.h" +#include "errors.h" inline void make_index_range ( @@ -14,7 +15,7 @@ make_index_range ( idx_end = nei_idx * 4 + 4; } else { - throw std::runtime_error("should no reach here"); + throw deepmd::deepmd_exception("should no reach here"); } } diff --git a/source/lib/src/prod_force_grad.cc b/source/lib/src/prod_force_grad.cc index 7872ea5c55..110bf790f4 100644 --- a/source/lib/src/prod_force_grad.cc +++ b/source/lib/src/prod_force_grad.cc @@ -2,6 +2,7 @@ #include #include #include "prod_force_grad.h" +#include "errors.h" 
inline void make_index_range ( @@ -15,7 +16,7 @@ make_index_range ( idx_end = nei_idx * 4 + 4; } else { - throw std::runtime_error("should no reach here"); + throw deepmd::deepmd_exception("should no reach here"); } } diff --git a/source/lib/src/prod_virial.cc b/source/lib/src/prod_virial.cc index 086bc94245..f1c598c807 100644 --- a/source/lib/src/prod_virial.cc +++ b/source/lib/src/prod_virial.cc @@ -2,6 +2,7 @@ #include #include #include "prod_virial.h" +#include "errors.h" inline void make_index_range ( @@ -15,7 +16,7 @@ make_index_range ( idx_end = nei_idx * 4 + 4; } else { - throw std::runtime_error("should no reach here"); + throw deepmd::deepmd_exception("should no reach here"); } } diff --git a/source/lib/src/prod_virial_grad.cc b/source/lib/src/prod_virial_grad.cc index 59c3192fc0..8e225c0793 100644 --- a/source/lib/src/prod_virial_grad.cc +++ b/source/lib/src/prod_virial_grad.cc @@ -1,6 +1,7 @@ #include #include #include "prod_virial_grad.h" +#include "errors.h" inline void make_index_range ( @@ -14,7 +15,7 @@ make_index_range ( idx_end = nei_idx * 4 + 4; } else { - throw std::runtime_error("should no reach here"); + throw deepmd::deepmd_exception("should no reach here"); } } diff --git a/source/lib/src/region.cc b/source/lib/src/region.cc index 62dcdb9b68..90704016c2 100644 --- a/source/lib/src/region.cc +++ b/source/lib/src/region.cc @@ -1,6 +1,7 @@ #include #include #include "region.h" +#include "errors.h" #define BOXT_DIM 9 using namespace deepmd; @@ -33,7 +34,7 @@ compute_volume(const FPTYPE * boxt) boxt[0*3+1] * (boxt[1*3+0]*boxt[2*3+2] - boxt[2*3+0]*boxt[1*3+2]) + boxt[0*3+2] * (boxt[1*3+0]*boxt[2*3+1] - boxt[2*3+0]*boxt[1*3+1]); if (volume < 0) { - throw std::runtime_error("Negative volume detected. Please make sure the simulation cell obeys the right-hand rule."); + throw deepmd::deepmd_exception("Negative volume detected. 
Please make sure the simulation cell obeys the right-hand rule."); } return volume; } diff --git a/source/op/descrpt.cc b/source/op/descrpt.cc index 10ba125594..7d1761f253 100644 --- a/source/op/descrpt.cc +++ b/source/op/descrpt.cc @@ -2,6 +2,7 @@ #include "ComputeDescriptor.h" #include "neighbor_list.h" #include "fmt_nlist.h" +#include "errors.h" typedef double boxtensor_t ; typedef double compute_t; @@ -49,6 +50,7 @@ class DescrptOp : public OpKernel { } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor const Tensor& coord_tensor = context->input(0); const Tensor& type_tensor = context->input(1); @@ -105,7 +107,7 @@ class DescrptOp : public OpKernel { nei_mode = -1; } else { - throw std::runtime_error("invalid mesh tensor"); + throw deepmd::deepmd_exception("invalid mesh tensor"); } bool b_pbc = true; // if region is given extended, do not use pbc @@ -254,7 +256,7 @@ class DescrptOp : public OpKernel { ::build_nlist (d_nlist_a, d_nlist_r, d_coord3, rcut_a, rcut_r, NULL); } else { - throw std::runtime_error("unknow neighbor mode"); + throw deepmd::deepmd_exception("unknow neighbor mode"); } // loop over atoms, compute descriptors for each atom @@ -360,6 +362,17 @@ class DescrptOp : public OpKernel { } } } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: float rcut_a; diff --git a/source/op/descrpt_se_a_ef.cc b/source/op/descrpt_se_a_ef.cc index 3ba41624d9..2858ad7d8b 100644 --- a/source/op/descrpt_se_a_ef.cc +++ b/source/op/descrpt_se_a_ef.cc @@ -3,6 +3,7 @@ #include "ComputeDescriptor.h" #include "neighbor_list.h" #include "fmt_nlist.h" +#include "errors.h" typedef double boxtensor_t ; typedef double 
compute_t; @@ -49,6 +50,7 @@ class DescrptSeAEfOp : public OpKernel { } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& coord_tensor = context->input(context_input_index++); @@ -112,7 +114,7 @@ class DescrptSeAEfOp : public OpKernel { nei_mode = -1; } else { - throw std::runtime_error("invalid mesh tensor"); + throw deepmd::deepmd_exception("invalid mesh tensor"); } bool b_pbc = true; // if region is given extended, do not use pbc @@ -267,7 +269,7 @@ class DescrptSeAEfOp : public OpKernel { ::build_nlist (d_nlist_a, d_nlist_r, d_coord3, rcut_a, rcut_r, NULL); } else { - throw std::runtime_error("unknow neighbor mode"); + throw deepmd::deepmd_exception("unknow neighbor mode"); } // loop over atoms, compute descriptors for each atom @@ -331,6 +333,17 @@ class DescrptSeAEfOp : public OpKernel { } } } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: float rcut_a; diff --git a/source/op/descrpt_se_a_ef_para.cc b/source/op/descrpt_se_a_ef_para.cc index 2cb3b3445c..55df81de6e 100644 --- a/source/op/descrpt_se_a_ef_para.cc +++ b/source/op/descrpt_se_a_ef_para.cc @@ -2,6 +2,7 @@ #include "ComputeDescriptor.h" #include "neighbor_list.h" #include "fmt_nlist.h" +#include "errors.h" typedef double boxtensor_t ; typedef double compute_t; @@ -48,6 +49,7 @@ class DescrptSeAEfParaOp : public OpKernel { } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& coord_tensor = context->input(context_input_index++); @@ -111,7 +113,7 @@ class DescrptSeAEfParaOp : public OpKernel { nei_mode = -1; } else { - throw 
std::runtime_error("invalid mesh tensor"); + throw deepmd::deepmd_exception("invalid mesh tensor"); } bool b_pbc = true; // if region is given extended, do not use pbc @@ -266,7 +268,7 @@ class DescrptSeAEfParaOp : public OpKernel { ::build_nlist (d_nlist_a, d_nlist_r, d_coord3, rcut_a, rcut_r, NULL); } else { - throw std::runtime_error("unknow neighbor mode"); + throw deepmd::deepmd_exception("unknow neighbor mode"); } // loop over atoms, compute descriptors for each atom @@ -330,6 +332,17 @@ class DescrptSeAEfParaOp : public OpKernel { } } } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: float rcut_a; diff --git a/source/op/descrpt_se_a_ef_vert.cc b/source/op/descrpt_se_a_ef_vert.cc index 615b153bf3..e5a2fbd8a4 100644 --- a/source/op/descrpt_se_a_ef_vert.cc +++ b/source/op/descrpt_se_a_ef_vert.cc @@ -2,6 +2,7 @@ #include "ComputeDescriptor.h" #include "neighbor_list.h" #include "fmt_nlist.h" +#include "errors.h" typedef double boxtensor_t ; typedef double compute_t; @@ -48,6 +49,7 @@ class DescrptSeAEfVertOp : public OpKernel { } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& coord_tensor = context->input(context_input_index++); @@ -111,7 +113,7 @@ class DescrptSeAEfVertOp : public OpKernel { nei_mode = -1; } else { - throw std::runtime_error("invalid mesh tensor"); + throw deepmd::deepmd_exception("invalid mesh tensor"); } bool b_pbc = true; // if region is given extended, do not use pbc @@ -266,7 +268,7 @@ class DescrptSeAEfVertOp : public OpKernel { ::build_nlist (d_nlist_a, d_nlist_r, d_coord3, rcut_a, rcut_r, NULL); } else { - throw 
std::runtime_error("unknow neighbor mode"); + throw deepmd::deepmd_exception("unknow neighbor mode"); } // loop over atoms, compute descriptors for each atom @@ -330,6 +332,17 @@ class DescrptSeAEfVertOp : public OpKernel { } } } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: float rcut_a; diff --git a/source/op/ewald_recp.cc b/source/op/ewald_recp.cc index 9159dc5931..d46208c451 100644 --- a/source/op/ewald_recp.cc +++ b/source/op/ewald_recp.cc @@ -1,5 +1,6 @@ #include "custom_op.h" #include "ewald.h" +#include "errors.h" typedef double boxtensor_t ; typedef double compute_t; @@ -28,6 +29,7 @@ class EwaldRecpOp : public OpKernel { } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int cc = 0; const Tensor& coord_tensor = context->input(cc++); @@ -115,6 +117,17 @@ class EwaldRecpOp : public OpKernel { virial(kk, ii) = d_virial[ii]; } } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: deepmd::EwaldParameters ep; diff --git a/source/op/gelu_multi_device.cc b/source/op/gelu_multi_device.cc index 508f60ccef..bbeeeffbe3 100644 --- a/source/op/gelu_multi_device.cc +++ b/source/op/gelu_multi_device.cc @@ -1,5 +1,6 @@ #include "custom_op.h" #include "gelu.h" +#include "errors.h" REGISTER_OP("Gelu") .Attr("T: {float, double} = DT_DOUBLE") @@ -26,6 +27,7 @@ class GeluOp : public 
OpKernel { public : explicit GeluOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { + try { // Grab the input tensor const Tensor& x_tensor = context->input(0); Tensor * output_tensor = NULL; @@ -61,6 +63,17 @@ class GeluOp : public OpKernel { out, x, size); } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private : std::string device; @@ -73,6 +86,7 @@ class GeluGradOp : public OpKernel { public : explicit GeluGradOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { + try { // Grab the input tensor const Tensor& dy_tensor = context->input(0); const Tensor& x_tensor = context->input(1); @@ -110,6 +124,17 @@ class GeluGradOp : public OpKernel { out, x, dy, size); } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private : std::string device; @@ -122,6 +147,7 @@ class GeluGradGradOp : public OpKernel { public : explicit GeluGradGradOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { + try { // Grab the input tensor const Tensor& dy_tensor = context->input(0); const Tensor& dy_2_tensor = context->input(1); @@ -157,6 +183,17 @@ class GeluGradGradOp : public OpKernel { out, x, dy, dy_2, size); } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + 
context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private : std::string device; diff --git a/source/op/legacy/descrpt_se_a.cc b/source/op/legacy/descrpt_se_a.cc index 51b8e26e0f..5bbfa6d0c9 100644 --- a/source/op/legacy/descrpt_se_a.cc +++ b/source/op/legacy/descrpt_se_a.cc @@ -107,7 +107,7 @@ class DescrptSeAOp : public OpKernel { nei_mode = -1; } else { - throw std::runtime_error("invalid mesh tensor"); + throw deepmd::deepmd_exception("invalid mesh tensor"); } bool b_pbc = true; // if region is given extended, do not use pbc @@ -253,7 +253,7 @@ class DescrptSeAOp : public OpKernel { ::build_nlist (d_nlist_a, d_nlist_r, d_coord3, rcut_a, rcut_r, NULL); } else { - throw std::runtime_error("unknow neighbor mode"); + throw deepmd::deepmd_exception("unknow neighbor mode"); } // loop over atoms, compute descriptors for each atom diff --git a/source/op/legacy/descrpt_se_r.cc b/source/op/legacy/descrpt_se_r.cc index 7031ed20e8..b811431922 100644 --- a/source/op/legacy/descrpt_se_r.cc +++ b/source/op/legacy/descrpt_se_r.cc @@ -99,7 +99,7 @@ class DescrptSeROp : public OpKernel { nei_mode = -1; } else { - throw std::runtime_error("invalid mesh tensor"); + throw deepmd::deepmd_exception("invalid mesh tensor"); } bool b_pbc = true; // if region is given extended, do not use pbc @@ -238,7 +238,7 @@ class DescrptSeROp : public OpKernel { ::build_nlist (d_nlist_null, d_nlist, d_coord3, -1, rcut, NULL); } else { - throw std::runtime_error("unknow neighbor mode"); + throw deepmd::deepmd_exception("unknow neighbor mode"); } // loop over atoms, compute descriptors for each atom diff --git a/source/op/map_aparam.cc b/source/op/map_aparam.cc index f1c98bdc9c..4313619495 100644 --- a/source/op/map_aparam.cc +++ 
b/source/op/map_aparam.cc @@ -1,5 +1,6 @@ #include "custom_op.h" #include "map_aparam.h" +#include "errors.h" REGISTER_OP("MapAparam") .Attr("T: {float, double} = DT_DOUBLE") @@ -20,6 +21,7 @@ class MapAparamOp : public OpKernel { } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& aparam_tensor = context->input(context_input_index++); @@ -70,6 +72,17 @@ class MapAparamOp : public OpKernel { nnei, numb_aparam); } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: int n_r_sel, n_a_sel, n_a_shift; diff --git a/source/op/neighbor_stat.cc b/source/op/neighbor_stat.cc index 11f991b4b7..55b0800a64 100644 --- a/source/op/neighbor_stat.cc +++ b/source/op/neighbor_stat.cc @@ -1,5 +1,6 @@ #include "custom_op.h" #include "neighbor_list.h" +#include "errors.h" typedef double boxtensor_t ; typedef double compute_t; @@ -23,6 +24,7 @@ class NeighborStatOp : public OpKernel { } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& coord_tensor = context->input(context_input_index++); @@ -60,7 +62,7 @@ class NeighborStatOp : public OpKernel { nei_mode = -1; } else { - throw std::runtime_error("invalid mesh tensor"); + throw deepmd::deepmd_exception("invalid mesh tensor"); } // if region is given extended, do not use pbc bool b_pbc = (nei_mode >= 1 || nei_mode == -1) ? 
false : true; @@ -139,7 +141,7 @@ class NeighborStatOp : public OpKernel { ::build_nlist (d_nlist_a, d_nlist_r, d_coord3, -1, rcut, NULL); } else { - throw std::runtime_error("unknow neighbor mode"); + throw deepmd::deepmd_exception("unknow neighbor mode"); } int MAX_NNEI = 0; @@ -167,6 +169,17 @@ class NeighborStatOp : public OpKernel { min_nbor_dist[ii * MAX_NNEI + jj] = sqrt(rij[0] * rij[0] + rij[1] * rij[1] + rij[2] * rij[2]); } } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: diff --git a/source/op/prod_env_mat_multi_device.cc b/source/op/prod_env_mat_multi_device.cc index 22fb223289..43235ae488 100644 --- a/source/op/prod_env_mat_multi_device.cc +++ b/source/op/prod_env_mat_multi_device.cc @@ -4,6 +4,7 @@ #include "region.h" #include "neighbor_list.h" #include "prod_env_mat.h" +#include "errors.h" REGISTER_OP("ProdEnvMatA") .Attr("T: {float, double} = DT_DOUBLE") @@ -321,6 +322,7 @@ class ProdEnvMatAOp : public OpKernel { } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& coord_tensor = context->input(context_input_index++); @@ -382,7 +384,7 @@ class ProdEnvMatAOp : public OpKernel { nei_mode = -1; } else { - throw std::runtime_error("invalid mesh tensor"); + throw deepmd::deepmd_exception("invalid mesh tensor"); } // Create output tensors @@ -542,6 +544,19 @@ class ProdEnvMatAOp : public OpKernel { if(b_nlist_map) _map_nlist_cpu(nlist, &idx_mapping[0], nloc, nnei); } } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", 
__LINE__)); + } catch (deepmd::deepmd_exception& e) { + // Ref: + // https://github.com/tensorflow/tensorflow/blob/5dcfc51118817f27fad5246812d83e5dccdc5f72/tensorflow/core/kernels/mkl/mkl_tfconv_op.h#L120-L126 + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } ///////////////////////////////////////////////////////////////////////////////////////////// @@ -584,6 +599,7 @@ class ProdEnvMatROp : public OpKernel { } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& coord_tensor = context->input(context_input_index++); @@ -642,7 +658,7 @@ class ProdEnvMatROp : public OpKernel { nei_mode = -1; } else { - throw std::runtime_error("invalid mesh tensor"); + throw deepmd::deepmd_exception("invalid mesh tensor"); } // Create an output tensor @@ -802,6 +818,17 @@ class ProdEnvMatROp : public OpKernel { if(b_nlist_map) _map_nlist_cpu(nlist, &idx_mapping[0], nloc, nnei); } } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } ///////////////////////////////////////////////////////////////////////////////////////////// diff --git a/source/op/prod_force.cc b/source/op/prod_force.cc index 307d00a85d..427624691a 100644 --- a/source/op/prod_force.cc +++ b/source/op/prod_force.cc @@ -1,4 +1,5 @@ #include "custom_op.h" +#include "errors.h" REGISTER_OP("ProdForce") .Attr("T: {float, double} = DT_DOUBLE") @@ -26,6 +27,7 @@ class ProdForceOp : public OpKernel { } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor const Tensor& net_deriv_tensor = context->input(0); const 
Tensor& in_deriv_tensor = context->input(1); @@ -139,6 +141,17 @@ class ProdForceOp : public OpKernel { } } } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: int n_r_sel, n_a_sel, n_a_shift; diff --git a/source/op/prod_force_grad.cc b/source/op/prod_force_grad.cc index 52c8ed845f..b49a50be53 100644 --- a/source/op/prod_force_grad.cc +++ b/source/op/prod_force_grad.cc @@ -1,4 +1,5 @@ #include "custom_op.h" +#include "errors.h" REGISTER_OP("ProdForceGrad") .Attr("T: {float, double} = DT_DOUBLE") @@ -25,6 +26,7 @@ class ProdForceGradOp : public OpKernel } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor const Tensor& grad_tensor = context->input(0); const Tensor& net_deriv_tensor = context->input(1); @@ -151,6 +153,17 @@ class ProdForceGradOp : public OpKernel } } } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: int n_r_sel, n_a_sel, n_a_shift; diff --git a/source/op/prod_force_grad_multi_device.cc b/source/op/prod_force_grad_multi_device.cc index 5aff4bbbef..e29ab01a5e 100644 --- a/source/op/prod_force_grad_multi_device.cc +++ b/source/op/prod_force_grad_multi_device.cc @@ -1,5 +1,6 @@ #include "custom_op.h" #include "prod_force_grad.h" +#include "errors.h" REGISTER_OP("ProdForceSeAGrad") .Attr("T: {float, double} = DT_DOUBLE") @@ -31,6 +32,7 @@ class ProdForceSeAGradOp : 
public OpKernel { } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& grad_tensor = context->input(context_input_index++); @@ -126,6 +128,17 @@ class ProdForceSeAGradOp : public OpKernel { grad, in_deriv, nlist, nloc, nnei); } } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: std::string device; @@ -139,6 +152,7 @@ class ProdForceSeRGradOp : public OpKernel explicit ProdForceSeRGradOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& grad_tensor = context->input(context_input_index++); @@ -234,6 +248,17 @@ class ProdForceSeRGradOp : public OpKernel grad, in_deriv, nlist, nloc, nnei); } } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: std::string device; diff --git a/source/op/prod_force_multi_device.cc b/source/op/prod_force_multi_device.cc index 63e6945906..e08e3a4dc1 100644 --- a/source/op/prod_force_multi_device.cc +++ b/source/op/prod_force_multi_device.cc @@ -1,5 +1,6 @@ #include "custom_op.h" #include "prod_force.h" +#include "errors.h" REGISTER_OP("ProdForceSeA") .Attr("T: {float, double} = DT_DOUBLE") @@ -25,6 +26,7 @@ class ProdForceSeAOp : public OpKernel { explicit 
ProdForceSeAOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& net_deriv_tensor = context->input(context_input_index++); @@ -101,6 +103,17 @@ class ProdForceSeAOp : public OpKernel { net_deriv, in_deriv, nlist, nloc, nall, nnei); } } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: std::string device; @@ -111,6 +124,7 @@ class ProdForceSeROp : public OpKernel { public: explicit ProdForceSeROp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& net_deriv_tensor = context->input(context_input_index++); @@ -186,6 +200,17 @@ class ProdForceSeROp : public OpKernel { net_deriv, in_deriv, nlist, nloc, nall, nnei); } } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: std::string device; diff --git a/source/op/prod_force_se_a_grad.cc b/source/op/prod_force_se_a_grad.cc index 7617c244ed..0dbc331437 100644 --- a/source/op/prod_force_se_a_grad.cc +++ b/source/op/prod_force_se_a_grad.cc @@ -1,5 +1,6 @@ #include "custom_op.h" #include "prod_force_grad.h" +#include "errors.h" REGISTER_OP("ProdForceSeAGrad") .Attr("T: {float, double} = DT_DOUBLE") @@ -25,6 +26,7 @@ 
class ProdForceSeAGradOp : public OpKernel } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& grad_tensor = context->input(context_input_index++); @@ -95,6 +97,17 @@ class ProdForceSeAGradOp : public OpKernel nloc, nnei); } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: int n_r_sel, n_a_sel, n_a_shift; diff --git a/source/op/prod_force_se_r_grad.cc b/source/op/prod_force_se_r_grad.cc index 9fff3724ed..0e4a53f113 100644 --- a/source/op/prod_force_se_r_grad.cc +++ b/source/op/prod_force_se_r_grad.cc @@ -1,5 +1,6 @@ #include "custom_op.h" #include "prod_force_grad.h" +#include "errors.h" REGISTER_OP("ProdForceSeRGrad") .Attr("T: {float, double} = DT_DOUBLE") @@ -20,6 +21,7 @@ class ProdForceSeRGradOp : public OpKernel } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& grad_tensor = context->input(context_input_index++); @@ -89,6 +91,17 @@ class ProdForceSeRGradOp : public OpKernel nloc, nnei); } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } }; diff --git a/source/op/prod_virial.cc b/source/op/prod_virial.cc index d83ab27225..0e3cecd258 100644 --- a/source/op/prod_virial.cc +++ b/source/op/prod_virial.cc @@ -1,4 +1,5 @@ #include "custom_op.h" +#include "errors.h" 
REGISTER_OP("ProdVirial") .Attr("T: {float, double} = DT_DOUBLE") @@ -28,6 +29,7 @@ class ProdVirialOp : public OpKernel { } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor const Tensor& net_deriv_tensor = context->input(0); const Tensor& in_deriv_tensor = context->input(1); @@ -160,6 +162,17 @@ class ProdVirialOp : public OpKernel { } } } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: int n_r_sel, n_a_sel, n_a_shift; diff --git a/source/op/prod_virial_grad.cc b/source/op/prod_virial_grad.cc index d07a661cb9..fac919872e 100644 --- a/source/op/prod_virial_grad.cc +++ b/source/op/prod_virial_grad.cc @@ -1,4 +1,5 @@ #include "custom_op.h" +#include "errors.h" REGISTER_OP("ProdVirialGrad") .Attr("T: {float, double} = DT_DOUBLE") @@ -26,6 +27,7 @@ class ProdVirialGradOp : public OpKernel } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor const Tensor& grad_tensor = context->input(0); const Tensor& net_deriv_tensor = context->input(1); @@ -160,6 +162,17 @@ class ProdVirialGradOp : public OpKernel } } } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: int n_r_sel, n_a_sel, n_a_shift; diff --git a/source/op/prod_virial_grad_multi_device.cc b/source/op/prod_virial_grad_multi_device.cc index 7a37da9b38..3bd94864c1 100644 --- 
a/source/op/prod_virial_grad_multi_device.cc +++ b/source/op/prod_virial_grad_multi_device.cc @@ -1,5 +1,6 @@ #include "custom_op.h" #include "prod_virial_grad.h" +#include "errors.h" REGISTER_OP("ProdVirialSeAGrad") .Attr("T: {float, double} = DT_DOUBLE") @@ -34,6 +35,7 @@ class ProdVirialSeAGradOp : public OpKernel } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& grad_tensor = context->input(context_input_index++); @@ -140,6 +142,17 @@ class ProdVirialSeAGradOp : public OpKernel grad, in_deriv, rij, nlist, nloc, nnei); } } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: std::string device; @@ -153,6 +166,7 @@ class ProdVirialSeRGradOp : public OpKernel explicit ProdVirialSeRGradOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& grad_tensor = context->input(context_input_index++); @@ -258,6 +272,17 @@ class ProdVirialSeRGradOp : public OpKernel grad, in_deriv, rij, nlist, nloc, nnei); } } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: std::string device; diff --git a/source/op/prod_virial_multi_device.cc b/source/op/prod_virial_multi_device.cc index 02c212a2d9..b3e72e0239 100644 --- 
a/source/op/prod_virial_multi_device.cc +++ b/source/op/prod_virial_multi_device.cc @@ -1,5 +1,6 @@ #include "custom_op.h" #include "prod_virial.h" +#include "errors.h" REGISTER_OP("ProdVirialSeA") .Attr("T: {float, double} = DT_DOUBLE") @@ -28,6 +29,7 @@ class ProdVirialSeAOp : public OpKernel { public: explicit ProdVirialSeAOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& net_deriv_tensor = context->input(context_input_index++); @@ -110,6 +112,17 @@ class ProdVirialSeAOp : public OpKernel { net_deriv, in_deriv, rij, nlist, nloc, nall, nnei); } } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: std::string device; @@ -120,6 +133,7 @@ class ProdVirialSeROp : public OpKernel { public: explicit ProdVirialSeROp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& net_deriv_tensor = context->input(context_input_index++); @@ -202,6 +216,17 @@ class ProdVirialSeROp : public OpKernel { net_deriv, in_deriv, rij, nlist, nloc, nall, nnei); } } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: std::string device; diff --git a/source/op/prod_virial_se_a_grad.cc 
b/source/op/prod_virial_se_a_grad.cc index cb76d29512..8adefbf5db 100644 --- a/source/op/prod_virial_se_a_grad.cc +++ b/source/op/prod_virial_se_a_grad.cc @@ -1,5 +1,6 @@ #include "custom_op.h" #include "prod_virial_grad.h" +#include "errors.h" REGISTER_OP("ProdVirialSeAGrad") .Attr("T: {float, double} = DT_DOUBLE") @@ -26,6 +27,7 @@ class ProdVirialSeAGradOp : public OpKernel } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& grad_tensor = context->input(context_input_index++); @@ -104,6 +106,17 @@ class ProdVirialSeAGradOp : public OpKernel nloc, nnei); } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: int n_r_sel, n_a_sel, n_a_shift; diff --git a/source/op/prod_virial_se_r_grad.cc b/source/op/prod_virial_se_r_grad.cc index 247f2ee909..7966a21a91 100644 --- a/source/op/prod_virial_se_r_grad.cc +++ b/source/op/prod_virial_se_r_grad.cc @@ -1,5 +1,6 @@ #include "custom_op.h" #include "prod_virial_grad.h" +#include "errors.h" REGISTER_OP("ProdVirialSeRGrad") .Attr("T: {float, double} = DT_DOUBLE") @@ -21,6 +22,7 @@ class ProdVirialSeRGradOp : public OpKernel } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& grad_tensor = context->input(context_input_index++); @@ -98,6 +100,17 @@ class ProdVirialSeRGradOp : public OpKernel nloc, nnei); } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + 
context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } }; diff --git a/source/op/soft_min.cc b/source/op/soft_min.cc index c30d9c409a..18fe2aab84 100644 --- a/source/op/soft_min.cc +++ b/source/op/soft_min.cc @@ -1,6 +1,7 @@ #include "custom_op.h" #include "ComputeDescriptor.h" #include "soft_min_switch.h" +#include "errors.h" REGISTER_OP("SoftMinSwitch") .Attr("T: {float, double} = DT_DOUBLE") @@ -37,6 +38,7 @@ class SoftMinSwitchOp : public OpKernel { } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int tmp_idx = 0; const Tensor& type_tensor = context->input(tmp_idx++); @@ -102,6 +104,17 @@ class SoftMinSwitchOp : public OpKernel { rmin, rmax); } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: std::vector sel_r; diff --git a/source/op/soft_min_force.cc b/source/op/soft_min_force.cc index 7d09da6613..2d944018c2 100644 --- a/source/op/soft_min_force.cc +++ b/source/op/soft_min_force.cc @@ -1,5 +1,6 @@ #include "custom_op.h" #include "soft_min_switch_force.h" +#include "errors.h" REGISTER_OP("SoftMinForce") .Attr("T: {float, double} = DT_DOUBLE") @@ -24,6 +25,7 @@ class SoftMinForceOp : public OpKernel { } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor const Tensor& du_tensor = context->input(0); const Tensor& sw_deriv_tensor = context->input(1); @@ -77,6 +79,17 @@ class SoftMinForceOp : public OpKernel { nall, nnei); } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", 
__LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: int n_r_sel, n_a_sel; diff --git a/source/op/soft_min_force_grad.cc b/source/op/soft_min_force_grad.cc index a7328734b6..c7b882e0f0 100644 --- a/source/op/soft_min_force_grad.cc +++ b/source/op/soft_min_force_grad.cc @@ -1,5 +1,6 @@ #include "custom_op.h" #include "soft_min_switch_force_grad.h" +#include "errors.h" REGISTER_OP("SoftMinForceGrad") .Attr("T: {float, double} = DT_DOUBLE") @@ -24,6 +25,7 @@ class SoftMinForceGradOp : public OpKernel } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& grad_tensor = context->input(context_input_index++); @@ -88,6 +90,17 @@ class SoftMinForceGradOp : public OpKernel nloc, nnei); } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: int n_r_sel, n_a_sel; diff --git a/source/op/soft_min_virial.cc b/source/op/soft_min_virial.cc index 3273160fe3..c4a830c7f1 100644 --- a/source/op/soft_min_virial.cc +++ b/source/op/soft_min_virial.cc @@ -1,5 +1,6 @@ #include "custom_op.h" #include "soft_min_switch_virial.h" +#include "errors.h" REGISTER_OP("SoftMinVirial") .Attr("T: {float, double} = DT_DOUBLE") @@ -26,6 +27,7 @@ class SoftMinVirialOp : public OpKernel { } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& du_tensor = context->input(context_input_index++); @@ -93,6 +95,17 @@ class SoftMinVirialOp : public OpKernel { nall, nnei); } + } catch 
(deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: int n_r_sel, n_a_sel; diff --git a/source/op/soft_min_virial_grad.cc b/source/op/soft_min_virial_grad.cc index 034aeb7a09..b73ed81d65 100644 --- a/source/op/soft_min_virial_grad.cc +++ b/source/op/soft_min_virial_grad.cc @@ -1,5 +1,6 @@ #include "custom_op.h" #include "soft_min_switch_virial_grad.h" +#include "errors.h" REGISTER_OP("SoftMinVirialGrad") .Attr("T: {float, double} = DT_DOUBLE") @@ -25,6 +26,7 @@ class SoftMinVirialGradOp : public OpKernel } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& grad_tensor = context->input(context_input_index++); @@ -96,6 +98,17 @@ class SoftMinVirialGradOp : public OpKernel nloc, nnei); } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: int n_r_sel, n_a_sel, n_a_shift; diff --git a/source/op/tabulate_multi_device.cc b/source/op/tabulate_multi_device.cc index 6fafa5698e..11b5b15250 100644 --- a/source/op/tabulate_multi_device.cc +++ b/source/op/tabulate_multi_device.cc @@ -1,5 +1,6 @@ #include "custom_op.h" #include "tabulate.h" +#include "errors.h" REGISTER_OP("TabulateFusion") .Attr("T: {float, double} = DT_DOUBLE") @@ -28,6 +29,7 @@ class TabulateFusionOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("last_layer_size", 
&last_layer_size)); } void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& table_tensor = context->input(context_input_index++); @@ -79,6 +81,17 @@ class TabulateFusionOp : public OpKernel { descriptor, table, table_info, em_x, em, nloc, nnei, last_layer_size); } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: int last_layer_size; @@ -90,6 +103,7 @@ class TabulateFusionGradOp : public OpKernel { public: explicit TabulateFusionGradOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& table_tensor = context->input(context_input_index++); @@ -147,6 +161,17 @@ class TabulateFusionGradOp : public OpKernel { dy_dem_x, dy_dem, table, table_info, em_x, em, dy, nloc, nnei, last_layer_size); } + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: std::string device; diff --git a/source/op/unaggregated_grad.cc b/source/op/unaggregated_grad.cc index 56502efc55..44bf9620a8 100644 --- a/source/op/unaggregated_grad.cc +++ b/source/op/unaggregated_grad.cc @@ -1,6 +1,7 @@ #include "custom_op.h" #include "ComputeDescriptor.h" #include "neighbor_list.h" +#include "errors.h" REGISTER_OP("UnaggregatedDyDxS") .Attr("T: {float, 
double} = DT_DOUBLE") @@ -136,6 +137,7 @@ class UnaggregatedDyDxSOp : public OpKernel { explicit UnaggregatedDyDxSOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& y = context->input(context_input_index++); @@ -159,6 +161,17 @@ class UnaggregatedDyDxSOp : public OpKernel { y.shape().dim_size(1), dy_dx->flat().data() ); + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: }; @@ -169,6 +182,7 @@ class UnaggregatedDy2DxSOp : public OpKernel { explicit UnaggregatedDy2DxSOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& y = context->input(context_input_index++); @@ -195,6 +209,17 @@ class UnaggregatedDy2DxSOp : public OpKernel { y.shape().dim_size(1), dy2_dx->flat().data() ); + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: }; @@ -205,6 +230,7 @@ class UnaggregatedDyDxOp : public OpKernel { explicit UnaggregatedDyDxOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& z = context->input(context_input_index++); @@ -232,6 +258,17 
@@ class UnaggregatedDyDxOp : public OpKernel { w.shape().dim_size(0), dz_dx->flat().data() ); + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: }; @@ -242,6 +279,7 @@ class UnaggregatedDy2DxOp : public OpKernel { explicit UnaggregatedDy2DxOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { + try { // Grab the input tensor int context_input_index = 0; const Tensor& z = context->input(context_input_index++); @@ -275,6 +313,17 @@ class UnaggregatedDy2DxOp : public OpKernel { w.shape().dim_size(0), dz2_dx->flat().data() ); + } catch (deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } } private: }; From 46ee28b66b87cc67881796f85b0d5d11b25c6674 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 4 Aug 2021 05:21:37 -0400 Subject: [PATCH 2/6] define try catch function --- source/op/CMakeLists.txt | 8 +-- source/op/custom_op.cc | 20 +++++++ source/op/custom_op.h | 5 ++ source/op/descrpt.cc | 21 ++----- source/op/descrpt_se_a_ef.cc | 21 ++----- source/op/descrpt_se_a_ef_para.cc | 21 ++----- source/op/descrpt_se_a_ef_vert.cc | 21 ++----- source/op/ewald_recp.cc | 17 ++---- source/op/gelu_multi_device.cc | 49 ++++------------ source/op/map_aparam.cc | 17 ++---- source/op/neighbor_stat.cc | 21 ++----- source/op/pair_tab.cc | 4 ++ source/op/prod_env_mat_multi_device.cc | 39 
++++--------- source/op/prod_force.cc | 17 ++---- source/op/prod_force_grad.cc | 17 ++---- source/op/prod_force_grad_multi_device.cc | 33 +++-------- source/op/prod_force_multi_device.cc | 29 ++-------- source/op/prod_force_se_a_grad.cc | 17 ++---- source/op/prod_force_se_r_grad.cc | 17 ++---- source/op/prod_virial.cc | 17 ++---- source/op/prod_virial_grad.cc | 17 ++---- source/op/prod_virial_grad_multi_device.cc | 33 +++-------- source/op/prod_virial_multi_device.cc | 33 +++-------- source/op/prod_virial_se_a_grad.cc | 17 ++---- source/op/prod_virial_se_r_grad.cc | 17 ++---- source/op/soft_min.cc | 17 ++---- source/op/soft_min_force.cc | 17 ++---- source/op/soft_min_force_grad.cc | 17 ++---- source/op/soft_min_virial.cc | 17 ++---- source/op/soft_min_virial_grad.cc | 17 ++---- source/op/tabulate_multi_device.cc | 33 +++-------- source/op/unaggregated_grad.cc | 65 ++++++---------------- 32 files changed, 197 insertions(+), 514 deletions(-) create mode 100644 source/op/custom_op.cc diff --git a/source/op/CMakeLists.txt b/source/op/CMakeLists.txt index 340c5601fb..1075847953 100644 --- a/source/op/CMakeLists.txt +++ b/source/op/CMakeLists.txt @@ -3,10 +3,10 @@ set(OP_LIB ${PROJECT_SOURCE_DIR}/lib/src/SimulationRegion.cpp ${PROJECT_SOURCE_DIR}/lib/src/neighbor_list.cc) set (OP_CXX_FLAG -D_GLIBCXX_USE_CXX11_ABI=${OP_CXX_ABI} ) -file(GLOB OP_SRC prod_force.cc prod_virial.cc descrpt.cc descrpt_se_a_ef.cc descrpt_se_a_ef.cc descrpt_se_a_ef_para.cc descrpt_se_a_ef_vert.cc pair_tab.cc prod_force_multi_device.cc prod_virial_multi_device.cc soft_min.cc soft_min_force.cc soft_min_virial.cc ewald_recp.cc gelu_multi_device.cc map_aparam.cc neighbor_stat.cc unaggregated_grad.cc tabulate_multi_device.cc prod_env_mat_multi_device.cc) -file(GLOB OP_CUDA_SRC prod_force.cc prod_virial.cc descrpt.cc prod_env_mat_multi_device.cc pair_tab.cc prod_force_multi_device.cc prod_virial_multi_device.cc soft_min.cc soft_min_force.cc soft_min_virial.cc gelu_multi_device.cc 
tabulate_multi_device.cc) -file(GLOB OP_ROCM_SRC prod_force.cc prod_virial.cc descrpt.cc prod_env_mat_multi_device.cc pair_tab.cc prod_force_multi_device.cc prod_virial_multi_device.cc soft_min.cc soft_min_force.cc soft_min_virial.cc gelu_multi_device.cc tabulate_multi_device.cc) -file(GLOB OP_GRADS_SRC prod_force_grad.cc prod_force_grad_multi_device.cc prod_virial_grad.cc prod_virial_grad_multi_device.cc soft_min_force_grad.cc soft_min_virial_grad.cc ) +file(GLOB OP_SRC custom_op.cc prod_force.cc prod_virial.cc descrpt.cc descrpt_se_a_ef.cc descrpt_se_a_ef.cc descrpt_se_a_ef_para.cc descrpt_se_a_ef_vert.cc pair_tab.cc prod_force_multi_device.cc prod_virial_multi_device.cc soft_min.cc soft_min_force.cc soft_min_virial.cc ewald_recp.cc gelu_multi_device.cc map_aparam.cc neighbor_stat.cc unaggregated_grad.cc tabulate_multi_device.cc prod_env_mat_multi_device.cc) +file(GLOB OP_CUDA_SRC custom_op.cc prod_force.cc prod_virial.cc descrpt.cc prod_env_mat_multi_device.cc pair_tab.cc prod_force_multi_device.cc prod_virial_multi_device.cc soft_min.cc soft_min_force.cc soft_min_virial.cc gelu_multi_device.cc tabulate_multi_device.cc) +file(GLOB OP_ROCM_SRC custom_op.cc prod_force.cc prod_virial.cc descrpt.cc prod_env_mat_multi_device.cc pair_tab.cc prod_force_multi_device.cc prod_virial_multi_device.cc soft_min.cc soft_min_force.cc soft_min_virial.cc gelu_multi_device.cc tabulate_multi_device.cc) +file(GLOB OP_GRADS_SRC custom_op.cc prod_force_grad.cc prod_force_grad_multi_device.cc prod_virial_grad.cc prod_virial_grad_multi_device.cc soft_min_force_grad.cc soft_min_virial_grad.cc ) file(GLOB OP_PY *.py) if (BUILD_CPP_IF) diff --git a/source/op/custom_op.cc b/source/op/custom_op.cc new file mode 100644 index 0000000000..80df56604b --- /dev/null +++ b/source/op/custom_op.cc @@ -0,0 +1,20 @@ +#include "custom_op.h" +#include "errors.h" + +namespace deepmd { + void save_compute(OpKernelContext* context, std::function ff) { + try{ + ff(context); + } catch 
(deepmd::deepmd_exception_oom& e){ + OP_REQUIRES_OK( + context, + errors::ResourceExhausted("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } catch (deepmd::deepmd_exception& e) { + OP_REQUIRES_OK( + context, + errors::Internal("Operation received an exception: ", e.what(), + ", in file ",__FILE__, ":", __LINE__)); + } + } +}; \ No newline at end of file diff --git a/source/op/custom_op.h b/source/op/custom_op.h index e4f9211e61..ce78449e77 100644 --- a/source/op/custom_op.h +++ b/source/op/custom_op.h @@ -26,4 +26,9 @@ struct DeviceFunctor { device = "GPU"; } #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +}; + +namespace deepmd { + typedef void ComputeFunction(OpKernelContext*); + void save_compute(OpKernelContext* context, std::function ff); }; \ No newline at end of file diff --git a/source/op/descrpt.cc b/source/op/descrpt.cc index 7d1761f253..df3efd5246 100644 --- a/source/op/descrpt.cc +++ b/source/op/descrpt.cc @@ -2,7 +2,6 @@ #include "ComputeDescriptor.h" #include "neighbor_list.h" #include "fmt_nlist.h" -#include "errors.h" typedef double boxtensor_t ; typedef double compute_t; @@ -50,7 +49,10 @@ class DescrptOp : public OpKernel { } void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor const Tensor& coord_tensor = context->input(0); const Tensor& type_tensor = context->input(1); @@ -107,7 +109,7 @@ class DescrptOp : public OpKernel { nei_mode = -1; } else { - throw deepmd::deepmd_exception("invalid mesh tensor"); + throw std::runtime_error("invalid mesh tensor"); } bool b_pbc = true; // if region is given extended, do not use pbc @@ -256,7 +258,7 @@ class DescrptOp : public OpKernel { ::build_nlist (d_nlist_a, d_nlist_r, d_coord3, rcut_a, rcut_r, NULL); } else { - throw deepmd::deepmd_exception("unknow neighbor mode"); + throw 
std::runtime_error("unknow neighbor mode"); } // loop over atoms, compute descriptors for each atom @@ -362,17 +364,6 @@ class DescrptOp : public OpKernel { } } } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: float rcut_a; diff --git a/source/op/descrpt_se_a_ef.cc b/source/op/descrpt_se_a_ef.cc index 2858ad7d8b..8edfcb45dc 100644 --- a/source/op/descrpt_se_a_ef.cc +++ b/source/op/descrpt_se_a_ef.cc @@ -3,7 +3,6 @@ #include "ComputeDescriptor.h" #include "neighbor_list.h" #include "fmt_nlist.h" -#include "errors.h" typedef double boxtensor_t ; typedef double compute_t; @@ -50,7 +49,10 @@ class DescrptSeAEfOp : public OpKernel { } void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& coord_tensor = context->input(context_input_index++); @@ -114,7 +116,7 @@ class DescrptSeAEfOp : public OpKernel { nei_mode = -1; } else { - throw deepmd::deepmd_exception("invalid mesh tensor"); + throw std::runtime_error("invalid mesh tensor"); } bool b_pbc = true; // if region is given extended, do not use pbc @@ -269,7 +271,7 @@ class DescrptSeAEfOp : public OpKernel { ::build_nlist (d_nlist_a, d_nlist_r, d_coord3, rcut_a, rcut_r, NULL); } else { - throw deepmd::deepmd_exception("unknow neighbor mode"); + throw std::runtime_error("unknow neighbor mode"); } // loop over atoms, compute descriptors for each atom @@ -333,17 +335,6 @@ class DescrptSeAEfOp : public OpKernel { } } } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - 
context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: float rcut_a; diff --git a/source/op/descrpt_se_a_ef_para.cc b/source/op/descrpt_se_a_ef_para.cc index 55df81de6e..65bd3a9269 100644 --- a/source/op/descrpt_se_a_ef_para.cc +++ b/source/op/descrpt_se_a_ef_para.cc @@ -2,7 +2,6 @@ #include "ComputeDescriptor.h" #include "neighbor_list.h" #include "fmt_nlist.h" -#include "errors.h" typedef double boxtensor_t ; typedef double compute_t; @@ -49,7 +48,10 @@ class DescrptSeAEfParaOp : public OpKernel { } void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& coord_tensor = context->input(context_input_index++); @@ -113,7 +115,7 @@ class DescrptSeAEfParaOp : public OpKernel { nei_mode = -1; } else { - throw deepmd::deepmd_exception("invalid mesh tensor"); + throw std::runtime_error("invalid mesh tensor"); } bool b_pbc = true; // if region is given extended, do not use pbc @@ -268,7 +270,7 @@ class DescrptSeAEfParaOp : public OpKernel { ::build_nlist (d_nlist_a, d_nlist_r, d_coord3, rcut_a, rcut_r, NULL); } else { - throw deepmd::deepmd_exception("unknow neighbor mode"); + throw std::runtime_error("unknow neighbor mode"); } // loop over atoms, compute descriptors for each atom @@ -332,17 +334,6 @@ class DescrptSeAEfParaOp : public OpKernel { } } } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - 
context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: float rcut_a; diff --git a/source/op/descrpt_se_a_ef_vert.cc b/source/op/descrpt_se_a_ef_vert.cc index e5a2fbd8a4..1c898886c7 100644 --- a/source/op/descrpt_se_a_ef_vert.cc +++ b/source/op/descrpt_se_a_ef_vert.cc @@ -2,7 +2,6 @@ #include "ComputeDescriptor.h" #include "neighbor_list.h" #include "fmt_nlist.h" -#include "errors.h" typedef double boxtensor_t ; typedef double compute_t; @@ -49,7 +48,10 @@ class DescrptSeAEfVertOp : public OpKernel { } void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& coord_tensor = context->input(context_input_index++); @@ -113,7 +115,7 @@ class DescrptSeAEfVertOp : public OpKernel { nei_mode = -1; } else { - throw deepmd::deepmd_exception("invalid mesh tensor"); + throw std::runtime_error("invalid mesh tensor"); } bool b_pbc = true; // if region is given extended, do not use pbc @@ -268,7 +270,7 @@ class DescrptSeAEfVertOp : public OpKernel { ::build_nlist (d_nlist_a, d_nlist_r, d_coord3, rcut_a, rcut_r, NULL); } else { - throw deepmd::deepmd_exception("unknow neighbor mode"); + throw std::runtime_error("unknow neighbor mode"); } // loop over atoms, compute descriptors for each atom @@ -332,17 +334,6 @@ class DescrptSeAEfVertOp : public OpKernel { } } } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: float rcut_a; diff --git a/source/op/ewald_recp.cc 
b/source/op/ewald_recp.cc index d46208c451..53c5fd1470 100644 --- a/source/op/ewald_recp.cc +++ b/source/op/ewald_recp.cc @@ -1,6 +1,5 @@ #include "custom_op.h" #include "ewald.h" -#include "errors.h" typedef double boxtensor_t ; typedef double compute_t; @@ -29,7 +28,10 @@ class EwaldRecpOp : public OpKernel { } void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int cc = 0; const Tensor& coord_tensor = context->input(cc++); @@ -117,17 +119,6 @@ class EwaldRecpOp : public OpKernel { virial(kk, ii) = d_virial[ii]; } } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: deepmd::EwaldParameters ep; diff --git a/source/op/gelu_multi_device.cc b/source/op/gelu_multi_device.cc index bbeeeffbe3..767d864694 100644 --- a/source/op/gelu_multi_device.cc +++ b/source/op/gelu_multi_device.cc @@ -1,6 +1,5 @@ #include "custom_op.h" #include "gelu.h" -#include "errors.h" REGISTER_OP("Gelu") .Attr("T: {float, double} = DT_DOUBLE") @@ -27,7 +26,10 @@ class GeluOp : public OpKernel { public : explicit GeluOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor const Tensor& x_tensor = context->input(0); Tensor * output_tensor = NULL; @@ -63,17 +65,6 @@ class GeluOp : public OpKernel { out, x, size); } - } catch (deepmd::deepmd_exception_oom& e){ - 
OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private : std::string device; @@ -86,7 +77,10 @@ class GeluGradOp : public OpKernel { public : explicit GeluGradOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor const Tensor& dy_tensor = context->input(0); const Tensor& x_tensor = context->input(1); @@ -124,17 +118,6 @@ class GeluGradOp : public OpKernel { out, x, dy, size); } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private : std::string device; @@ -147,7 +130,10 @@ class GeluGradGradOp : public OpKernel { public : explicit GeluGradGradOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor const Tensor& dy_tensor = context->input(0); const Tensor& dy_2_tensor = context->input(1); @@ -183,17 +169,6 @@ class GeluGradGradOp : public OpKernel { out, x, dy, dy_2, size); } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", 
e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private : std::string device; diff --git a/source/op/map_aparam.cc b/source/op/map_aparam.cc index 4313619495..c0b65a28e6 100644 --- a/source/op/map_aparam.cc +++ b/source/op/map_aparam.cc @@ -1,6 +1,5 @@ #include "custom_op.h" #include "map_aparam.h" -#include "errors.h" REGISTER_OP("MapAparam") .Attr("T: {float, double} = DT_DOUBLE") @@ -21,7 +20,10 @@ class MapAparamOp : public OpKernel { } void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& aparam_tensor = context->input(context_input_index++); @@ -72,17 +74,6 @@ class MapAparamOp : public OpKernel { nnei, numb_aparam); } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: int n_r_sel, n_a_sel, n_a_shift; diff --git a/source/op/neighbor_stat.cc b/source/op/neighbor_stat.cc index 55b0800a64..8aab137892 100644 --- a/source/op/neighbor_stat.cc +++ b/source/op/neighbor_stat.cc @@ -1,6 +1,5 @@ #include "custom_op.h" #include "neighbor_list.h" -#include "errors.h" typedef double boxtensor_t ; typedef double compute_t; @@ -24,7 +23,10 @@ class NeighborStatOp : public OpKernel { } void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void 
_Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& coord_tensor = context->input(context_input_index++); @@ -62,7 +64,7 @@ class NeighborStatOp : public OpKernel { nei_mode = -1; } else { - throw deepmd::deepmd_exception("invalid mesh tensor"); + throw std::runtime_error("invalid mesh tensor"); } // if region is given extended, do not use pbc bool b_pbc = (nei_mode >= 1 || nei_mode == -1) ? false : true; @@ -141,7 +143,7 @@ class NeighborStatOp : public OpKernel { ::build_nlist (d_nlist_a, d_nlist_r, d_coord3, -1, rcut, NULL); } else { - throw deepmd::deepmd_exception("unknow neighbor mode"); + throw std::runtime_error("unknow neighbor mode"); } int MAX_NNEI = 0; @@ -169,17 +171,6 @@ class NeighborStatOp : public OpKernel { min_nbor_dist[ii * MAX_NNEI + jj] = sqrt(rij[0] * rij[0] + rij[1] * rij[1] + rij[2] * rij[2]); } } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: diff --git a/source/op/pair_tab.cc b/source/op/pair_tab.cc index e09ef460b4..eccf3e001f 100644 --- a/source/op/pair_tab.cc +++ b/source/op/pair_tab.cc @@ -34,6 +34,10 @@ class PairTabOp : public OpKernel { } void Compute(OpKernelContext* context) override { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int tmp_idx = 0; const Tensor& table_info_tensor = context->input(tmp_idx++); diff --git a/source/op/prod_env_mat_multi_device.cc b/source/op/prod_env_mat_multi_device.cc index 43235ae488..293580fad6 100644 --- a/source/op/prod_env_mat_multi_device.cc +++ b/source/op/prod_env_mat_multi_device.cc @@ 
-4,7 +4,6 @@ #include "region.h" #include "neighbor_list.h" #include "prod_env_mat.h" -#include "errors.h" REGISTER_OP("ProdEnvMatA") .Attr("T: {float, double} = DT_DOUBLE") @@ -322,7 +321,10 @@ class ProdEnvMatAOp : public OpKernel { } void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& coord_tensor = context->input(context_input_index++); @@ -384,7 +386,7 @@ class ProdEnvMatAOp : public OpKernel { nei_mode = -1; } else { - throw deepmd::deepmd_exception("invalid mesh tensor"); + throw std::runtime_error("invalid mesh tensor"); } // Create output tensors @@ -544,19 +546,6 @@ class ProdEnvMatAOp : public OpKernel { if(b_nlist_map) _map_nlist_cpu(nlist, &idx_mapping[0], nloc, nnei); } } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - // Ref: - // https://github.com/tensorflow/tensorflow/blob/5dcfc51118817f27fad5246812d83e5dccdc5f72/tensorflow/core/kernels/mkl/mkl_tfconv_op.h#L120-L126 - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } ///////////////////////////////////////////////////////////////////////////////////////////// @@ -599,7 +588,10 @@ class ProdEnvMatROp : public OpKernel { } void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& coord_tensor = context->input(context_input_index++); @@ -658,7 +650,7 @@ class ProdEnvMatROp : public OpKernel { nei_mode = 
-1; } else { - throw deepmd::deepmd_exception("invalid mesh tensor"); + throw std::runtime_error("invalid mesh tensor"); } // Create an output tensor @@ -818,17 +810,6 @@ class ProdEnvMatROp : public OpKernel { if(b_nlist_map) _map_nlist_cpu(nlist, &idx_mapping[0], nloc, nnei); } } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } ///////////////////////////////////////////////////////////////////////////////////////////// diff --git a/source/op/prod_force.cc b/source/op/prod_force.cc index 427624691a..46cfcfceda 100644 --- a/source/op/prod_force.cc +++ b/source/op/prod_force.cc @@ -1,5 +1,4 @@ #include "custom_op.h" -#include "errors.h" REGISTER_OP("ProdForce") .Attr("T: {float, double} = DT_DOUBLE") @@ -27,7 +26,10 @@ class ProdForceOp : public OpKernel { } void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor const Tensor& net_deriv_tensor = context->input(0); const Tensor& in_deriv_tensor = context->input(1); @@ -141,17 +143,6 @@ class ProdForceOp : public OpKernel { } } } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: int n_r_sel, n_a_sel, n_a_shift; diff --git a/source/op/prod_force_grad.cc b/source/op/prod_force_grad.cc index 
b49a50be53..d406e3e320 100644 --- a/source/op/prod_force_grad.cc +++ b/source/op/prod_force_grad.cc @@ -1,5 +1,4 @@ #include "custom_op.h" -#include "errors.h" REGISTER_OP("ProdForceGrad") .Attr("T: {float, double} = DT_DOUBLE") @@ -26,7 +25,10 @@ class ProdForceGradOp : public OpKernel } void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor const Tensor& grad_tensor = context->input(0); const Tensor& net_deriv_tensor = context->input(1); @@ -153,17 +155,6 @@ class ProdForceGradOp : public OpKernel } } } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: int n_r_sel, n_a_sel, n_a_shift; diff --git a/source/op/prod_force_grad_multi_device.cc b/source/op/prod_force_grad_multi_device.cc index e29ab01a5e..497b7945f0 100644 --- a/source/op/prod_force_grad_multi_device.cc +++ b/source/op/prod_force_grad_multi_device.cc @@ -1,6 +1,5 @@ #include "custom_op.h" #include "prod_force_grad.h" -#include "errors.h" REGISTER_OP("ProdForceSeAGrad") .Attr("T: {float, double} = DT_DOUBLE") @@ -32,7 +31,10 @@ class ProdForceSeAGradOp : public OpKernel { } void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& grad_tensor = context->input(context_input_index++); @@ -128,17 +130,6 @@ class ProdForceSeAGradOp : public OpKernel { grad, in_deriv, nlist, nloc, nnei); } } - } catch 
(deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: std::string device; @@ -152,7 +143,10 @@ class ProdForceSeRGradOp : public OpKernel explicit ProdForceSeRGradOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& grad_tensor = context->input(context_input_index++); @@ -248,17 +242,6 @@ class ProdForceSeRGradOp : public OpKernel grad, in_deriv, nlist, nloc, nnei); } } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: std::string device; diff --git a/source/op/prod_force_multi_device.cc b/source/op/prod_force_multi_device.cc index e08e3a4dc1..94b859bca9 100644 --- a/source/op/prod_force_multi_device.cc +++ b/source/op/prod_force_multi_device.cc @@ -1,6 +1,5 @@ #include "custom_op.h" #include "prod_force.h" -#include "errors.h" REGISTER_OP("ProdForceSeA") .Attr("T: {float, double} = DT_DOUBLE") @@ -26,7 +25,10 @@ class ProdForceSeAOp : public OpKernel { explicit ProdForceSeAOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) 
{this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& net_deriv_tensor = context->input(context_input_index++); @@ -103,17 +105,6 @@ class ProdForceSeAOp : public OpKernel { net_deriv, in_deriv, nlist, nloc, nall, nnei); } } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: std::string device; @@ -124,7 +115,6 @@ class ProdForceSeROp : public OpKernel { public: explicit ProdForceSeROp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - try { // Grab the input tensor int context_input_index = 0; const Tensor& net_deriv_tensor = context->input(context_input_index++); @@ -200,17 +190,6 @@ class ProdForceSeROp : public OpKernel { net_deriv, in_deriv, nlist, nloc, nall, nnei); } } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: std::string device; diff --git a/source/op/prod_force_se_a_grad.cc b/source/op/prod_force_se_a_grad.cc index 0dbc331437..21eab7b2ce 100644 --- a/source/op/prod_force_se_a_grad.cc +++ b/source/op/prod_force_se_a_grad.cc @@ -1,6 +1,5 @@ #include "custom_op.h" #include "prod_force_grad.h" -#include "errors.h" REGISTER_OP("ProdForceSeAGrad") .Attr("T: {float, double} = DT_DOUBLE") @@ -26,7 +25,10 @@ class ProdForceSeAGradOp : public OpKernel } 
void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& grad_tensor = context->input(context_input_index++); @@ -97,17 +99,6 @@ class ProdForceSeAGradOp : public OpKernel nloc, nnei); } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: int n_r_sel, n_a_sel, n_a_shift; diff --git a/source/op/prod_force_se_r_grad.cc b/source/op/prod_force_se_r_grad.cc index 0e4a53f113..5ccdf2431c 100644 --- a/source/op/prod_force_se_r_grad.cc +++ b/source/op/prod_force_se_r_grad.cc @@ -1,6 +1,5 @@ #include "custom_op.h" #include "prod_force_grad.h" -#include "errors.h" REGISTER_OP("ProdForceSeRGrad") .Attr("T: {float, double} = DT_DOUBLE") @@ -21,7 +20,10 @@ class ProdForceSeRGradOp : public OpKernel } void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& grad_tensor = context->input(context_input_index++); @@ -91,17 +93,6 @@ class ProdForceSeRGradOp : public OpKernel nloc, nnei); } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", 
__LINE__)); - } } }; diff --git a/source/op/prod_virial.cc b/source/op/prod_virial.cc index 0e3cecd258..4dfc4d6824 100644 --- a/source/op/prod_virial.cc +++ b/source/op/prod_virial.cc @@ -1,5 +1,4 @@ #include "custom_op.h" -#include "errors.h" REGISTER_OP("ProdVirial") .Attr("T: {float, double} = DT_DOUBLE") @@ -29,7 +28,10 @@ class ProdVirialOp : public OpKernel { } void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor const Tensor& net_deriv_tensor = context->input(0); const Tensor& in_deriv_tensor = context->input(1); @@ -162,17 +164,6 @@ class ProdVirialOp : public OpKernel { } } } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: int n_r_sel, n_a_sel, n_a_shift; diff --git a/source/op/prod_virial_grad.cc b/source/op/prod_virial_grad.cc index fac919872e..a644da56ed 100644 --- a/source/op/prod_virial_grad.cc +++ b/source/op/prod_virial_grad.cc @@ -1,5 +1,4 @@ #include "custom_op.h" -#include "errors.h" REGISTER_OP("ProdVirialGrad") .Attr("T: {float, double} = DT_DOUBLE") @@ -27,7 +26,10 @@ class ProdVirialGradOp : public OpKernel } void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor const Tensor& grad_tensor = context->input(0); const Tensor& net_deriv_tensor = context->input(1); @@ -162,17 +164,6 @@ class ProdVirialGradOp : public OpKernel } } } - } catch (deepmd::deepmd_exception_oom& e){ - 
OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: int n_r_sel, n_a_sel, n_a_shift; diff --git a/source/op/prod_virial_grad_multi_device.cc b/source/op/prod_virial_grad_multi_device.cc index 3bd94864c1..e62cb5ac1a 100644 --- a/source/op/prod_virial_grad_multi_device.cc +++ b/source/op/prod_virial_grad_multi_device.cc @@ -1,6 +1,5 @@ #include "custom_op.h" #include "prod_virial_grad.h" -#include "errors.h" REGISTER_OP("ProdVirialSeAGrad") .Attr("T: {float, double} = DT_DOUBLE") @@ -35,7 +34,10 @@ class ProdVirialSeAGradOp : public OpKernel } void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& grad_tensor = context->input(context_input_index++); @@ -142,17 +144,6 @@ class ProdVirialSeAGradOp : public OpKernel grad, in_deriv, rij, nlist, nloc, nnei); } } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: std::string device; @@ -166,7 +157,10 @@ class ProdVirialSeRGradOp : public OpKernel explicit ProdVirialSeRGradOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void 
_Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& grad_tensor = context->input(context_input_index++); @@ -272,17 +266,6 @@ class ProdVirialSeRGradOp : public OpKernel grad, in_deriv, rij, nlist, nloc, nnei); } } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: std::string device; diff --git a/source/op/prod_virial_multi_device.cc b/source/op/prod_virial_multi_device.cc index b3e72e0239..dbb8105edc 100644 --- a/source/op/prod_virial_multi_device.cc +++ b/source/op/prod_virial_multi_device.cc @@ -1,6 +1,5 @@ #include "custom_op.h" #include "prod_virial.h" -#include "errors.h" REGISTER_OP("ProdVirialSeA") .Attr("T: {float, double} = DT_DOUBLE") @@ -29,7 +28,10 @@ class ProdVirialSeAOp : public OpKernel { public: explicit ProdVirialSeAOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& net_deriv_tensor = context->input(context_input_index++); @@ -112,17 +114,6 @@ class ProdVirialSeAOp : public OpKernel { net_deriv, in_deriv, rij, nlist, nloc, nall, nnei); } } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", 
__LINE__)); - } } private: std::string device; @@ -133,7 +124,10 @@ class ProdVirialSeROp : public OpKernel { public: explicit ProdVirialSeROp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& net_deriv_tensor = context->input(context_input_index++); @@ -216,17 +210,6 @@ class ProdVirialSeROp : public OpKernel { net_deriv, in_deriv, rij, nlist, nloc, nall, nnei); } } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: std::string device; diff --git a/source/op/prod_virial_se_a_grad.cc b/source/op/prod_virial_se_a_grad.cc index 8adefbf5db..3df66beca3 100644 --- a/source/op/prod_virial_se_a_grad.cc +++ b/source/op/prod_virial_se_a_grad.cc @@ -1,6 +1,5 @@ #include "custom_op.h" #include "prod_virial_grad.h" -#include "errors.h" REGISTER_OP("ProdVirialSeAGrad") .Attr("T: {float, double} = DT_DOUBLE") @@ -27,7 +26,10 @@ class ProdVirialSeAGradOp : public OpKernel } void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& grad_tensor = context->input(context_input_index++); @@ -106,17 +108,6 @@ class ProdVirialSeAGradOp : public OpKernel nloc, nnei); } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation 
received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: int n_r_sel, n_a_sel, n_a_shift; diff --git a/source/op/prod_virial_se_r_grad.cc b/source/op/prod_virial_se_r_grad.cc index 7966a21a91..53910c991e 100644 --- a/source/op/prod_virial_se_r_grad.cc +++ b/source/op/prod_virial_se_r_grad.cc @@ -1,6 +1,5 @@ #include "custom_op.h" #include "prod_virial_grad.h" -#include "errors.h" REGISTER_OP("ProdVirialSeRGrad") .Attr("T: {float, double} = DT_DOUBLE") @@ -22,7 +21,10 @@ class ProdVirialSeRGradOp : public OpKernel } void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& grad_tensor = context->input(context_input_index++); @@ -100,17 +102,6 @@ class ProdVirialSeRGradOp : public OpKernel nloc, nnei); } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } }; diff --git a/source/op/soft_min.cc b/source/op/soft_min.cc index 18fe2aab84..524fdaf4ef 100644 --- a/source/op/soft_min.cc +++ b/source/op/soft_min.cc @@ -1,7 +1,6 @@ #include "custom_op.h" #include "ComputeDescriptor.h" #include "soft_min_switch.h" -#include "errors.h" REGISTER_OP("SoftMinSwitch") .Attr("T: {float, double} = DT_DOUBLE") @@ -38,7 +37,10 @@ class SoftMinSwitchOp : public OpKernel { } void Compute(OpKernelContext* context) override { - try { + 
deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int tmp_idx = 0; const Tensor& type_tensor = context->input(tmp_idx++); @@ -104,17 +106,6 @@ class SoftMinSwitchOp : public OpKernel { rmin, rmax); } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: std::vector sel_r; diff --git a/source/op/soft_min_force.cc b/source/op/soft_min_force.cc index 2d944018c2..fbd9bed1e3 100644 --- a/source/op/soft_min_force.cc +++ b/source/op/soft_min_force.cc @@ -1,6 +1,5 @@ #include "custom_op.h" #include "soft_min_switch_force.h" -#include "errors.h" REGISTER_OP("SoftMinForce") .Attr("T: {float, double} = DT_DOUBLE") @@ -25,7 +24,10 @@ class SoftMinForceOp : public OpKernel { } void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor const Tensor& du_tensor = context->input(0); const Tensor& sw_deriv_tensor = context->input(1); @@ -79,17 +81,6 @@ class SoftMinForceOp : public OpKernel { nall, nnei); } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: int n_r_sel, n_a_sel; diff --git a/source/op/soft_min_force_grad.cc 
b/source/op/soft_min_force_grad.cc index c7b882e0f0..2bb00b30b6 100644 --- a/source/op/soft_min_force_grad.cc +++ b/source/op/soft_min_force_grad.cc @@ -1,6 +1,5 @@ #include "custom_op.h" #include "soft_min_switch_force_grad.h" -#include "errors.h" REGISTER_OP("SoftMinForceGrad") .Attr("T: {float, double} = DT_DOUBLE") @@ -25,7 +24,10 @@ class SoftMinForceGradOp : public OpKernel } void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& grad_tensor = context->input(context_input_index++); @@ -90,17 +92,6 @@ class SoftMinForceGradOp : public OpKernel nloc, nnei); } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: int n_r_sel, n_a_sel; diff --git a/source/op/soft_min_virial.cc b/source/op/soft_min_virial.cc index c4a830c7f1..2f9ab3a149 100644 --- a/source/op/soft_min_virial.cc +++ b/source/op/soft_min_virial.cc @@ -1,6 +1,5 @@ #include "custom_op.h" #include "soft_min_switch_virial.h" -#include "errors.h" REGISTER_OP("SoftMinVirial") .Attr("T: {float, double} = DT_DOUBLE") @@ -27,7 +26,10 @@ class SoftMinVirialOp : public OpKernel { } void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& du_tensor = context->input(context_input_index++); @@ -95,17 +97,6 @@ class SoftMinVirialOp : public OpKernel { nall, nnei); } - } 
catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: int n_r_sel, n_a_sel; diff --git a/source/op/soft_min_virial_grad.cc b/source/op/soft_min_virial_grad.cc index b73ed81d65..fe6c944692 100644 --- a/source/op/soft_min_virial_grad.cc +++ b/source/op/soft_min_virial_grad.cc @@ -1,6 +1,5 @@ #include "custom_op.h" #include "soft_min_switch_virial_grad.h" -#include "errors.h" REGISTER_OP("SoftMinVirialGrad") .Attr("T: {float, double} = DT_DOUBLE") @@ -26,7 +25,10 @@ class SoftMinVirialGradOp : public OpKernel } void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& grad_tensor = context->input(context_input_index++); @@ -98,17 +100,6 @@ class SoftMinVirialGradOp : public OpKernel nloc, nnei); } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: int n_r_sel, n_a_sel, n_a_shift; diff --git a/source/op/tabulate_multi_device.cc b/source/op/tabulate_multi_device.cc index 11b5b15250..afbc736dc5 100644 --- a/source/op/tabulate_multi_device.cc +++ b/source/op/tabulate_multi_device.cc @@ -1,6 +1,5 @@ #include "custom_op.h" #include "tabulate.h" -#include "errors.h" REGISTER_OP("TabulateFusion") .Attr("T: {float, 
double} = DT_DOUBLE") @@ -29,7 +28,10 @@ class TabulateFusionOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("last_layer_size", &last_layer_size)); } void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& table_tensor = context->input(context_input_index++); @@ -81,17 +83,6 @@ class TabulateFusionOp : public OpKernel { descriptor, table, table_info, em_x, em, nloc, nnei, last_layer_size); } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: int last_layer_size; @@ -103,7 +94,10 @@ class TabulateFusionGradOp : public OpKernel { public: explicit TabulateFusionGradOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& table_tensor = context->input(context_input_index++); @@ -161,17 +155,6 @@ class TabulateFusionGradOp : public OpKernel { dy_dem_x, dy_dem, table, table_info, em_x, em, dy, nloc, nnei, last_layer_size); } - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", 
in file ",__FILE__, ":", __LINE__)); - } } private: std::string device; diff --git a/source/op/unaggregated_grad.cc b/source/op/unaggregated_grad.cc index 44bf9620a8..1c719ed7b9 100644 --- a/source/op/unaggregated_grad.cc +++ b/source/op/unaggregated_grad.cc @@ -1,7 +1,6 @@ #include "custom_op.h" #include "ComputeDescriptor.h" #include "neighbor_list.h" -#include "errors.h" REGISTER_OP("UnaggregatedDyDxS") .Attr("T: {float, double} = DT_DOUBLE") @@ -137,7 +136,10 @@ class UnaggregatedDyDxSOp : public OpKernel { explicit UnaggregatedDyDxSOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& y = context->input(context_input_index++); @@ -161,17 +163,6 @@ class UnaggregatedDyDxSOp : public OpKernel { y.shape().dim_size(1), dy_dx->flat().data() ); - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: }; @@ -182,7 +173,10 @@ class UnaggregatedDy2DxSOp : public OpKernel { explicit UnaggregatedDy2DxSOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& y = context->input(context_input_index++); @@ -209,17 +203,6 @@ class UnaggregatedDy2DxSOp : public OpKernel { y.shape().dim_size(1), dy2_dx->flat().data() ); - } 
catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: }; @@ -230,7 +213,10 @@ class UnaggregatedDyDxOp : public OpKernel { explicit UnaggregatedDyDxOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& z = context->input(context_input_index++); @@ -258,17 +244,6 @@ class UnaggregatedDyDxOp : public OpKernel { w.shape().dim_size(0), dz_dx->flat().data() ); - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: }; @@ -279,7 +254,10 @@ class UnaggregatedDy2DxOp : public OpKernel { explicit UnaggregatedDy2DxOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - try { + deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + } + + void _Compute(OpKernelContext* context) { // Grab the input tensor int context_input_index = 0; const Tensor& z = context->input(context_input_index++); @@ -313,17 +291,6 @@ class UnaggregatedDy2DxOp : public OpKernel { w.shape().dim_size(0), dz2_dx->flat().data() ); - } catch (deepmd::deepmd_exception_oom& e){ - OP_REQUIRES_OK( - context, - 
errors::ResourceExhausted("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } catch (deepmd::deepmd_exception& e) { - OP_REQUIRES_OK( - context, - errors::Internal("Operation received an exception: ", e.what(), - ", in file ",__FILE__, ":", __LINE__)); - } } private: }; From 3b57006b22e56e208e11a3a405e07bfee69f745a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 4 Aug 2021 05:23:03 -0400 Subject: [PATCH 3/6] replace std::runtime_error --- source/op/descrpt.cc | 4 ++-- source/op/descrpt_se_a_ef.cc | 4 ++-- source/op/descrpt_se_a_ef_para.cc | 4 ++-- source/op/descrpt_se_a_ef_vert.cc | 4 ++-- source/op/neighbor_stat.cc | 4 ++-- source/op/prod_env_mat_multi_device.cc | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/source/op/descrpt.cc b/source/op/descrpt.cc index df3efd5246..91ff45c974 100644 --- a/source/op/descrpt.cc +++ b/source/op/descrpt.cc @@ -109,7 +109,7 @@ class DescrptOp : public OpKernel { nei_mode = -1; } else { - throw std::runtime_error("invalid mesh tensor"); + throw deepmd::deepmd_exception("invalid mesh tensor"); } bool b_pbc = true; // if region is given extended, do not use pbc @@ -258,7 +258,7 @@ class DescrptOp : public OpKernel { ::build_nlist (d_nlist_a, d_nlist_r, d_coord3, rcut_a, rcut_r, NULL); } else { - throw std::runtime_error("unknow neighbor mode"); + throw deepmd::deepmd_exception("unknow neighbor mode"); } // loop over atoms, compute descriptors for each atom diff --git a/source/op/descrpt_se_a_ef.cc b/source/op/descrpt_se_a_ef.cc index 8edfcb45dc..2cf7d05f0c 100644 --- a/source/op/descrpt_se_a_ef.cc +++ b/source/op/descrpt_se_a_ef.cc @@ -116,7 +116,7 @@ class DescrptSeAEfOp : public OpKernel { nei_mode = -1; } else { - throw std::runtime_error("invalid mesh tensor"); + throw deepmd::deepmd_exception("invalid mesh tensor"); } bool b_pbc = true; // if region is given extended, do not use pbc @@ -271,7 +271,7 @@ class DescrptSeAEfOp : public OpKernel { ::build_nlist 
(d_nlist_a, d_nlist_r, d_coord3, rcut_a, rcut_r, NULL); } else { - throw std::runtime_error("unknow neighbor mode"); + throw deepmd::deepmd_exception("unknow neighbor mode"); } // loop over atoms, compute descriptors for each atom diff --git a/source/op/descrpt_se_a_ef_para.cc b/source/op/descrpt_se_a_ef_para.cc index 65bd3a9269..0d33b5fdb6 100644 --- a/source/op/descrpt_se_a_ef_para.cc +++ b/source/op/descrpt_se_a_ef_para.cc @@ -115,7 +115,7 @@ class DescrptSeAEfParaOp : public OpKernel { nei_mode = -1; } else { - throw std::runtime_error("invalid mesh tensor"); + throw deepmd::deepmd_exception("invalid mesh tensor"); } bool b_pbc = true; // if region is given extended, do not use pbc @@ -270,7 +270,7 @@ class DescrptSeAEfParaOp : public OpKernel { ::build_nlist (d_nlist_a, d_nlist_r, d_coord3, rcut_a, rcut_r, NULL); } else { - throw std::runtime_error("unknow neighbor mode"); + throw deepmd::deepmd_exception("unknow neighbor mode"); } // loop over atoms, compute descriptors for each atom diff --git a/source/op/descrpt_se_a_ef_vert.cc b/source/op/descrpt_se_a_ef_vert.cc index 1c898886c7..8bcc974ab4 100644 --- a/source/op/descrpt_se_a_ef_vert.cc +++ b/source/op/descrpt_se_a_ef_vert.cc @@ -115,7 +115,7 @@ class DescrptSeAEfVertOp : public OpKernel { nei_mode = -1; } else { - throw std::runtime_error("invalid mesh tensor"); + throw deepmd::deepmd_exception("invalid mesh tensor"); } bool b_pbc = true; // if region is given extended, do not use pbc @@ -270,7 +270,7 @@ class DescrptSeAEfVertOp : public OpKernel { ::build_nlist (d_nlist_a, d_nlist_r, d_coord3, rcut_a, rcut_r, NULL); } else { - throw std::runtime_error("unknow neighbor mode"); + throw deepmd::deepmd_exception("unknow neighbor mode"); } // loop over atoms, compute descriptors for each atom diff --git a/source/op/neighbor_stat.cc b/source/op/neighbor_stat.cc index 8aab137892..91da924f5d 100644 --- a/source/op/neighbor_stat.cc +++ b/source/op/neighbor_stat.cc @@ -64,7 +64,7 @@ class NeighborStatOp : public 
OpKernel { nei_mode = -1; } else { - throw std::runtime_error("invalid mesh tensor"); + throw deepmd::deepmd_exception("invalid mesh tensor"); } // if region is given extended, do not use pbc bool b_pbc = (nei_mode >= 1 || nei_mode == -1) ? false : true; @@ -143,7 +143,7 @@ class NeighborStatOp : public OpKernel { ::build_nlist (d_nlist_a, d_nlist_r, d_coord3, -1, rcut, NULL); } else { - throw std::runtime_error("unknow neighbor mode"); + throw deepmd::deepmd_exception("unknow neighbor mode"); } int MAX_NNEI = 0; diff --git a/source/op/prod_env_mat_multi_device.cc b/source/op/prod_env_mat_multi_device.cc index 293580fad6..82008bfcb8 100644 --- a/source/op/prod_env_mat_multi_device.cc +++ b/source/op/prod_env_mat_multi_device.cc @@ -386,7 +386,7 @@ class ProdEnvMatAOp : public OpKernel { nei_mode = -1; } else { - throw std::runtime_error("invalid mesh tensor"); + throw deepmd::deepmd_exception("invalid mesh tensor"); } // Create output tensors @@ -650,7 +650,7 @@ class ProdEnvMatROp : public OpKernel { nei_mode = -1; } else { - throw std::runtime_error("invalid mesh tensor"); + throw deepmd::deepmd_exception("invalid mesh tensor"); } // Create an output tensor From e70d780538f38d503be9e1e7e746f3e39e051457 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 4 Aug 2021 05:24:50 -0400 Subject: [PATCH 4/6] add headers --- source/op/descrpt.cc | 1 + source/op/descrpt_se_a_ef.cc | 1 + source/op/descrpt_se_a_ef_para.cc | 1 + source/op/descrpt_se_a_ef_vert.cc | 1 + source/op/legacy/descrpt_se_a.cc | 1 + source/op/legacy/descrpt_se_r.cc | 1 + source/op/neighbor_stat.cc | 1 + source/op/prod_env_mat_multi_device.cc | 1 + 8 files changed, 8 insertions(+) diff --git a/source/op/descrpt.cc b/source/op/descrpt.cc index 91ff45c974..77c5004941 100644 --- a/source/op/descrpt.cc +++ b/source/op/descrpt.cc @@ -2,6 +2,7 @@ #include "ComputeDescriptor.h" #include "neighbor_list.h" #include "fmt_nlist.h" +#include "errors.h" typedef double boxtensor_t ; typedef double compute_t; diff 
--git a/source/op/descrpt_se_a_ef.cc b/source/op/descrpt_se_a_ef.cc index 2cf7d05f0c..9cdbb2a1c6 100644 --- a/source/op/descrpt_se_a_ef.cc +++ b/source/op/descrpt_se_a_ef.cc @@ -3,6 +3,7 @@ #include "ComputeDescriptor.h" #include "neighbor_list.h" #include "fmt_nlist.h" +#include "errors.h" typedef double boxtensor_t ; typedef double compute_t; diff --git a/source/op/descrpt_se_a_ef_para.cc b/source/op/descrpt_se_a_ef_para.cc index 0d33b5fdb6..33923ed829 100644 --- a/source/op/descrpt_se_a_ef_para.cc +++ b/source/op/descrpt_se_a_ef_para.cc @@ -2,6 +2,7 @@ #include "ComputeDescriptor.h" #include "neighbor_list.h" #include "fmt_nlist.h" +#include "errors.h" typedef double boxtensor_t ; typedef double compute_t; diff --git a/source/op/descrpt_se_a_ef_vert.cc b/source/op/descrpt_se_a_ef_vert.cc index 8bcc974ab4..5aa6376e95 100644 --- a/source/op/descrpt_se_a_ef_vert.cc +++ b/source/op/descrpt_se_a_ef_vert.cc @@ -2,6 +2,7 @@ #include "ComputeDescriptor.h" #include "neighbor_list.h" #include "fmt_nlist.h" +#include "errors.h" typedef double boxtensor_t ; typedef double compute_t; diff --git a/source/op/legacy/descrpt_se_a.cc b/source/op/legacy/descrpt_se_a.cc index 5bbfa6d0c9..cd7abf8a76 100644 --- a/source/op/legacy/descrpt_se_a.cc +++ b/source/op/legacy/descrpt_se_a.cc @@ -3,6 +3,7 @@ #include "neighbor_list.h" #include "fmt_nlist.h" #include "env_mat.h" +#include "errors.h" typedef double boxtensor_t ; typedef double compute_t; diff --git a/source/op/legacy/descrpt_se_r.cc b/source/op/legacy/descrpt_se_r.cc index b811431922..408818fbee 100644 --- a/source/op/legacy/descrpt_se_r.cc +++ b/source/op/legacy/descrpt_se_r.cc @@ -3,6 +3,7 @@ #include "neighbor_list.h" #include "fmt_nlist.h" #include "env_mat.h" +#include "errors.h" typedef double boxtensor_t ; typedef double compute_t; diff --git a/source/op/neighbor_stat.cc b/source/op/neighbor_stat.cc index 91da924f5d..52eeff8b7a 100644 --- a/source/op/neighbor_stat.cc +++ b/source/op/neighbor_stat.cc @@ -1,5 +1,6 @@ 
#include "custom_op.h" #include "neighbor_list.h" +#include "errors.h" typedef double boxtensor_t ; typedef double compute_t; diff --git a/source/op/prod_env_mat_multi_device.cc b/source/op/prod_env_mat_multi_device.cc index 82008bfcb8..9e0a7abb06 100644 --- a/source/op/prod_env_mat_multi_device.cc +++ b/source/op/prod_env_mat_multi_device.cc @@ -4,6 +4,7 @@ #include "region.h" #include "neighbor_list.h" #include "prod_env_mat.h" +#include "errors.h" REGISTER_OP("ProdEnvMatA") .Attr("T: {float, double} = DT_DOUBLE") From 9c8a0da4a91b8d060be312cafa82c379ccc8fc37 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 4 Aug 2021 05:26:15 -0400 Subject: [PATCH 5/6] clean useless line --- source/op/custom_op.h | 1 - 1 file changed, 1 deletion(-) diff --git a/source/op/custom_op.h b/source/op/custom_op.h index ce78449e77..08838cda7e 100644 --- a/source/op/custom_op.h +++ b/source/op/custom_op.h @@ -29,6 +29,5 @@ struct DeviceFunctor { }; namespace deepmd { - typedef void ComputeFunction(OpKernelContext*); void save_compute(OpKernelContext* context, std::function ff); }; \ No newline at end of file From 01cad22a1d5ae1fdfc7dfddcf68feebdefdbe897 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 4 Aug 2021 15:48:26 -0400 Subject: [PATCH 6/6] add custom_op.cc to api_cc tests and rename save_compute to safe_compute --- source/api_cc/tests/CMakeLists.txt | 2 +- source/op/custom_op.cc | 2 +- source/op/custom_op.h | 2 +- source/op/descrpt.cc | 2 +- source/op/descrpt_se_a_ef.cc | 2 +- source/op/descrpt_se_a_ef_para.cc | 2 +- source/op/descrpt_se_a_ef_vert.cc | 2 +- source/op/ewald_recp.cc | 2 +- source/op/gelu_multi_device.cc | 6 +++--- source/op/map_aparam.cc | 2 +- source/op/neighbor_stat.cc | 2 +- source/op/pair_tab.cc | 2 +- source/op/prod_env_mat_multi_device.cc | 4 ++-- source/op/prod_force.cc | 2 +- source/op/prod_force_grad.cc | 2 +- source/op/prod_force_grad_multi_device.cc | 4 ++-- source/op/prod_force_multi_device.cc | 2 +- source/op/prod_force_se_a_grad.cc | 2 +- 
source/op/prod_force_se_r_grad.cc | 2 +- source/op/prod_virial.cc | 2 +- source/op/prod_virial_grad.cc | 2 +- source/op/prod_virial_grad_multi_device.cc | 4 ++-- source/op/prod_virial_multi_device.cc | 4 ++-- source/op/prod_virial_se_a_grad.cc | 2 +- source/op/prod_virial_se_r_grad.cc | 2 +- source/op/soft_min.cc | 2 +- source/op/soft_min_force.cc | 2 +- source/op/soft_min_force_grad.cc | 2 +- source/op/soft_min_virial.cc | 2 +- source/op/soft_min_virial_grad.cc | 2 +- source/op/tabulate_multi_device.cc | 4 ++-- source/op/unaggregated_grad.cc | 8 ++++---- 32 files changed, 42 insertions(+), 42 deletions(-) diff --git a/source/api_cc/tests/CMakeLists.txt b/source/api_cc/tests/CMakeLists.txt index 6768ff2ee6..1a5b56fca0 100644 --- a/source/api_cc/tests/CMakeLists.txt +++ b/source/api_cc/tests/CMakeLists.txt @@ -37,7 +37,7 @@ configure_file( set(opname "deepmd_op") set(OP_BASE_DIR ${CMAKE_SOURCE_DIR}/../../op) # file(GLOB OP_SRC ${OP_BASE_DIR}/*.cc) -file(GLOB OP_SRC ${OP_BASE_DIR}/prod_force.cc ${OP_BASE_DIR}/prod_virial.cc ${OP_BASE_DIR}/descrpt.cc ${OP_BASE_DIR}/descrpt_se_a_ef.cc ${OP_BASE_DIR}/descrpt_se_a_ef.cc ${OP_BASE_DIR}/descrpt_se_a_ef_para.cc ${OP_BASE_DIR}/descrpt_se_a_ef_vert.cc ${OP_BASE_DIR}/pair_tab.cc ${OP_BASE_DIR}/prod_force_multi_device.cc ${OP_BASE_DIR}/prod_virial_multi_device.cc ${OP_BASE_DIR}/soft_min.cc ${OP_BASE_DIR}/soft_min_force.cc ${OP_BASE_DIR}/soft_min_virial.cc ${OP_BASE_DIR}/ewald_recp.cc ${OP_BASE_DIR}/gelu_multi_device.cc ${OP_BASE_DIR}/map_aparam.cc ${OP_BASE_DIR}/neighbor_stat.cc ${OP_BASE_DIR}/unaggregated_grad.cc ${OP_BASE_DIR}/tabulate_multi_device.cc ${OP_BASE_DIR}/prod_env_mat_multi_device.cc) +file(GLOB OP_SRC ${OP_BASE_DIR}/custom_op.cc ${OP_BASE_DIR}/prod_force.cc ${OP_BASE_DIR}/prod_virial.cc ${OP_BASE_DIR}/descrpt.cc ${OP_BASE_DIR}/descrpt_se_a_ef.cc ${OP_BASE_DIR}/descrpt_se_a_ef.cc ${OP_BASE_DIR}/descrpt_se_a_ef_para.cc ${OP_BASE_DIR}/descrpt_se_a_ef_vert.cc ${OP_BASE_DIR}/pair_tab.cc 
${OP_BASE_DIR}/prod_force_multi_device.cc ${OP_BASE_DIR}/prod_virial_multi_device.cc ${OP_BASE_DIR}/soft_min.cc ${OP_BASE_DIR}/soft_min_force.cc ${OP_BASE_DIR}/soft_min_virial.cc ${OP_BASE_DIR}/ewald_recp.cc ${OP_BASE_DIR}/gelu_multi_device.cc ${OP_BASE_DIR}/map_aparam.cc ${OP_BASE_DIR}/neighbor_stat.cc ${OP_BASE_DIR}/unaggregated_grad.cc ${OP_BASE_DIR}/tabulate_multi_device.cc ${OP_BASE_DIR}/prod_env_mat_multi_device.cc) add_library(${opname} SHARED ${OP_SRC}) list (APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/../../cmake/) diff --git a/source/op/custom_op.cc b/source/op/custom_op.cc index 80df56604b..741fb3ace6 100644 --- a/source/op/custom_op.cc +++ b/source/op/custom_op.cc @@ -2,7 +2,7 @@ #include "errors.h" namespace deepmd { - void save_compute(OpKernelContext* context, std::function ff) { + void safe_compute(OpKernelContext* context, std::function ff) { try{ ff(context); } catch (deepmd::deepmd_exception_oom& e){ diff --git a/source/op/custom_op.h b/source/op/custom_op.h index 08838cda7e..8482e92b03 100644 --- a/source/op/custom_op.h +++ b/source/op/custom_op.h @@ -29,5 +29,5 @@ struct DeviceFunctor { }; namespace deepmd { - void save_compute(OpKernelContext* context, std::function ff); + void safe_compute(OpKernelContext* context, std::function ff); }; \ No newline at end of file diff --git a/source/op/descrpt.cc b/source/op/descrpt.cc index 77c5004941..7fdf81d986 100644 --- a/source/op/descrpt.cc +++ b/source/op/descrpt.cc @@ -50,7 +50,7 @@ class DescrptOp : public OpKernel { } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/descrpt_se_a_ef.cc b/source/op/descrpt_se_a_ef.cc index 9cdbb2a1c6..121205c9cf 100644 --- a/source/op/descrpt_se_a_ef.cc +++ b/source/op/descrpt_se_a_ef.cc @@ -50,7 +50,7 @@ class 
DescrptSeAEfOp : public OpKernel { } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/descrpt_se_a_ef_para.cc b/source/op/descrpt_se_a_ef_para.cc index 33923ed829..952c53d473 100644 --- a/source/op/descrpt_se_a_ef_para.cc +++ b/source/op/descrpt_se_a_ef_para.cc @@ -49,7 +49,7 @@ class DescrptSeAEfParaOp : public OpKernel { } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/descrpt_se_a_ef_vert.cc b/source/op/descrpt_se_a_ef_vert.cc index 5aa6376e95..4ef76f8e0f 100644 --- a/source/op/descrpt_se_a_ef_vert.cc +++ b/source/op/descrpt_se_a_ef_vert.cc @@ -49,7 +49,7 @@ class DescrptSeAEfVertOp : public OpKernel { } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/ewald_recp.cc b/source/op/ewald_recp.cc index 53c5fd1470..c9cc22b480 100644 --- a/source/op/ewald_recp.cc +++ b/source/op/ewald_recp.cc @@ -28,7 +28,7 @@ class EwaldRecpOp : public OpKernel { } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/gelu_multi_device.cc b/source/op/gelu_multi_device.cc index 
767d864694..dc86ab6c8d 100644 --- a/source/op/gelu_multi_device.cc +++ b/source/op/gelu_multi_device.cc @@ -26,7 +26,7 @@ class GeluOp : public OpKernel { public : explicit GeluOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { @@ -77,7 +77,7 @@ class GeluGradOp : public OpKernel { public : explicit GeluGradOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { @@ -130,7 +130,7 @@ class GeluGradGradOp : public OpKernel { public : explicit GeluGradGradOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/map_aparam.cc b/source/op/map_aparam.cc index c0b65a28e6..cd70435f99 100644 --- a/source/op/map_aparam.cc +++ b/source/op/map_aparam.cc @@ -20,7 +20,7 @@ class MapAparamOp : public OpKernel { } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/neighbor_stat.cc b/source/op/neighbor_stat.cc index 52eeff8b7a..fad4617cc5 100644 --- 
a/source/op/neighbor_stat.cc +++ b/source/op/neighbor_stat.cc @@ -24,7 +24,7 @@ class NeighborStatOp : public OpKernel { } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/pair_tab.cc b/source/op/pair_tab.cc index eccf3e001f..2a22e17102 100644 --- a/source/op/pair_tab.cc +++ b/source/op/pair_tab.cc @@ -34,7 +34,7 @@ class PairTabOp : public OpKernel { } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/prod_env_mat_multi_device.cc b/source/op/prod_env_mat_multi_device.cc index 9e0a7abb06..69e08eaa5e 100644 --- a/source/op/prod_env_mat_multi_device.cc +++ b/source/op/prod_env_mat_multi_device.cc @@ -322,7 +322,7 @@ class ProdEnvMatAOp : public OpKernel { } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { @@ -589,7 +589,7 @@ class ProdEnvMatROp : public OpKernel { } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/prod_force.cc b/source/op/prod_force.cc index 46cfcfceda..a97fb6c575 100644 --- a/source/op/prod_force.cc +++ b/source/op/prod_force.cc @@ -26,7 +26,7 @@ class ProdForceOp : 
public OpKernel { } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/prod_force_grad.cc b/source/op/prod_force_grad.cc index d406e3e320..67423d7489 100644 --- a/source/op/prod_force_grad.cc +++ b/source/op/prod_force_grad.cc @@ -25,7 +25,7 @@ class ProdForceGradOp : public OpKernel } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/prod_force_grad_multi_device.cc b/source/op/prod_force_grad_multi_device.cc index 497b7945f0..533f6cbf14 100644 --- a/source/op/prod_force_grad_multi_device.cc +++ b/source/op/prod_force_grad_multi_device.cc @@ -31,7 +31,7 @@ class ProdForceSeAGradOp : public OpKernel { } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { @@ -143,7 +143,7 @@ class ProdForceSeRGradOp : public OpKernel explicit ProdForceSeRGradOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/prod_force_multi_device.cc b/source/op/prod_force_multi_device.cc index 94b859bca9..8df25636f6 100644 --- a/source/op/prod_force_multi_device.cc +++ 
b/source/op/prod_force_multi_device.cc @@ -25,7 +25,7 @@ class ProdForceSeAOp : public OpKernel { explicit ProdForceSeAOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/prod_force_se_a_grad.cc b/source/op/prod_force_se_a_grad.cc index 21eab7b2ce..84b2a7ed3b 100644 --- a/source/op/prod_force_se_a_grad.cc +++ b/source/op/prod_force_se_a_grad.cc @@ -25,7 +25,7 @@ class ProdForceSeAGradOp : public OpKernel } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/prod_force_se_r_grad.cc b/source/op/prod_force_se_r_grad.cc index 5ccdf2431c..e02f0c8750 100644 --- a/source/op/prod_force_se_r_grad.cc +++ b/source/op/prod_force_se_r_grad.cc @@ -20,7 +20,7 @@ class ProdForceSeRGradOp : public OpKernel } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/prod_virial.cc b/source/op/prod_virial.cc index 4dfc4d6824..a8df2bc848 100644 --- a/source/op/prod_virial.cc +++ b/source/op/prod_virial.cc @@ -28,7 +28,7 @@ class ProdVirialOp : public OpKernel { } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) 
{this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/prod_virial_grad.cc b/source/op/prod_virial_grad.cc index a644da56ed..33fa0348dc 100644 --- a/source/op/prod_virial_grad.cc +++ b/source/op/prod_virial_grad.cc @@ -26,7 +26,7 @@ class ProdVirialGradOp : public OpKernel } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/prod_virial_grad_multi_device.cc b/source/op/prod_virial_grad_multi_device.cc index e62cb5ac1a..9afd4462eb 100644 --- a/source/op/prod_virial_grad_multi_device.cc +++ b/source/op/prod_virial_grad_multi_device.cc @@ -34,7 +34,7 @@ class ProdVirialSeAGradOp : public OpKernel } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { @@ -157,7 +157,7 @@ class ProdVirialSeRGradOp : public OpKernel explicit ProdVirialSeRGradOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/prod_virial_multi_device.cc b/source/op/prod_virial_multi_device.cc index dbb8105edc..33c263ef84 100644 --- a/source/op/prod_virial_multi_device.cc +++ b/source/op/prod_virial_multi_device.cc @@ -28,7 +28,7 @@ class ProdVirialSeAOp : public OpKernel { public: explicit ProdVirialSeAOp(OpKernelConstruction* context) : OpKernel(context) {} void 
Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { @@ -124,7 +124,7 @@ class ProdVirialSeROp : public OpKernel { public: explicit ProdVirialSeROp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/prod_virial_se_a_grad.cc b/source/op/prod_virial_se_a_grad.cc index 3df66beca3..00a88e0f76 100644 --- a/source/op/prod_virial_se_a_grad.cc +++ b/source/op/prod_virial_se_a_grad.cc @@ -26,7 +26,7 @@ class ProdVirialSeAGradOp : public OpKernel } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/prod_virial_se_r_grad.cc b/source/op/prod_virial_se_r_grad.cc index 53910c991e..7f9005abe4 100644 --- a/source/op/prod_virial_se_r_grad.cc +++ b/source/op/prod_virial_se_r_grad.cc @@ -21,7 +21,7 @@ class ProdVirialSeRGradOp : public OpKernel } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/soft_min.cc b/source/op/soft_min.cc index 524fdaf4ef..f7770ab58b 100644 --- a/source/op/soft_min.cc +++ b/source/op/soft_min.cc @@ -37,7 +37,7 @@ class SoftMinSwitchOp : 
public OpKernel { } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/soft_min_force.cc b/source/op/soft_min_force.cc index fbd9bed1e3..f10a48dc26 100644 --- a/source/op/soft_min_force.cc +++ b/source/op/soft_min_force.cc @@ -24,7 +24,7 @@ class SoftMinForceOp : public OpKernel { } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/soft_min_force_grad.cc b/source/op/soft_min_force_grad.cc index 2bb00b30b6..d5095d1005 100644 --- a/source/op/soft_min_force_grad.cc +++ b/source/op/soft_min_force_grad.cc @@ -24,7 +24,7 @@ class SoftMinForceGradOp : public OpKernel } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/soft_min_virial.cc b/source/op/soft_min_virial.cc index 2f9ab3a149..72d4a21e55 100644 --- a/source/op/soft_min_virial.cc +++ b/source/op/soft_min_virial.cc @@ -26,7 +26,7 @@ class SoftMinVirialOp : public OpKernel { } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/soft_min_virial_grad.cc b/source/op/soft_min_virial_grad.cc index fe6c944692..f92ac2a5c9 
100644 --- a/source/op/soft_min_virial_grad.cc +++ b/source/op/soft_min_virial_grad.cc @@ -25,7 +25,7 @@ class SoftMinVirialGradOp : public OpKernel } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/tabulate_multi_device.cc b/source/op/tabulate_multi_device.cc index afbc736dc5..3d5765b843 100644 --- a/source/op/tabulate_multi_device.cc +++ b/source/op/tabulate_multi_device.cc @@ -28,7 +28,7 @@ class TabulateFusionOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("last_layer_size", &last_layer_size)); } void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { @@ -94,7 +94,7 @@ class TabulateFusionGradOp : public OpKernel { public: explicit TabulateFusionGradOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { diff --git a/source/op/unaggregated_grad.cc b/source/op/unaggregated_grad.cc index 1c719ed7b9..343a339a92 100644 --- a/source/op/unaggregated_grad.cc +++ b/source/op/unaggregated_grad.cc @@ -136,7 +136,7 @@ class UnaggregatedDyDxSOp : public OpKernel { explicit UnaggregatedDyDxSOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + 
deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { @@ -173,7 +173,7 @@ class UnaggregatedDy2DxSOp : public OpKernel { explicit UnaggregatedDy2DxSOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { @@ -213,7 +213,7 @@ class UnaggregatedDyDxOp : public OpKernel { explicit UnaggregatedDyDxOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) { @@ -254,7 +254,7 @@ class UnaggregatedDy2DxOp : public OpKernel { explicit UnaggregatedDy2DxOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { - deepmd::save_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); + deepmd::safe_compute(context, [this](OpKernelContext* context) {this->_Compute(context);}); } void _Compute(OpKernelContext* context) {