Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions source/op/tf/custom_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,16 @@
#pragma once
#include <iostream>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/public/version.h"

#if (TF_MAJOR_VERSION > 2) || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION >= 20)
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#endif

#include "device.h"
#include "neighbor_list.h"
#include "tensorflow/core/framework/op.h"
Expand All @@ -27,6 +35,25 @@ void safe_compute(OpKernelContext* context,
std::function<void(OpKernelContext*)> ff);
};

namespace deepmd {
namespace tf_compat {
#if (TF_MAJOR_VERSION > 2) || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION >= 20)
using Status = absl::Status;
#else
using Status = tensorflow::Status;
#endif

template <typename... Args>
inline Status InvalidArgument(Args&&... args) {
#if (TF_MAJOR_VERSION > 2) || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION >= 20)
return absl::InvalidArgumentError(absl::StrCat(std::forward<Args>(args)...));
#else
return tensorflow::errors::InvalidArgument(std::forward<Args>(args)...);
#endif
}
} // namespace tf_compat
} // namespace deepmd

template <typename FPTYPE>
void _prepare_coord_nlist_gpu(OpKernelContext* context,
Tensor* tensor_list,
Expand Down
68 changes: 39 additions & 29 deletions source/op/tf/descrpt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,22 +67,23 @@ class DescrptOp : public OpKernel {

// set size of the sample
OP_REQUIRES(context, (coord_tensor.shape().dims() == 2),
errors::InvalidArgument("Dim of coord should be 2"));
deepmd::tf_compat::InvalidArgument("Dim of coord should be 2"));
OP_REQUIRES(context, (type_tensor.shape().dims() == 2),
errors::InvalidArgument("Dim of type should be 2"));
OP_REQUIRES(context, (natoms_tensor.shape().dims() == 1),
errors::InvalidArgument("Dim of natoms should be 1"));
deepmd::tf_compat::InvalidArgument("Dim of type should be 2"));
OP_REQUIRES(
context, (natoms_tensor.shape().dims() == 1),
deepmd::tf_compat::InvalidArgument("Dim of natoms should be 1"));
OP_REQUIRES(context, (box_tensor.shape().dims() == 2),
errors::InvalidArgument("Dim of box should be 2"));
deepmd::tf_compat::InvalidArgument("Dim of box should be 2"));
OP_REQUIRES(context, (mesh_tensor.shape().dims() == 1),
errors::InvalidArgument("Dim of mesh should be 1"));
deepmd::tf_compat::InvalidArgument("Dim of mesh should be 1"));
OP_REQUIRES(context, (avg_tensor.shape().dims() == 2),
errors::InvalidArgument("Dim of avg should be 2"));
deepmd::tf_compat::InvalidArgument("Dim of avg should be 2"));
OP_REQUIRES(context, (std_tensor.shape().dims() == 2),
errors::InvalidArgument("Dim of std should be 2"));
deepmd::tf_compat::InvalidArgument("Dim of std should be 2"));

OP_REQUIRES(context, (natoms_tensor.shape().dim_size(0) >= 3),
errors::InvalidArgument(
deepmd::tf_compat::InvalidArgument(
"number of atoms should be larger than (or equal to) 3"));
auto natoms = natoms_tensor.flat<int>();
int nloc = natoms(0);
Expand All @@ -91,25 +92,34 @@ class DescrptOp : public OpKernel {
int nsamples = coord_tensor.shape().dim_size(0);

// check the sizes
OP_REQUIRES(context, (nsamples == type_tensor.shape().dim_size(0)),
errors::InvalidArgument("number of samples should match"));
OP_REQUIRES(context, (nsamples == box_tensor.shape().dim_size(0)),
errors::InvalidArgument("number of samples should match"));
OP_REQUIRES(context, (ntypes == avg_tensor.shape().dim_size(0)),
errors::InvalidArgument("number of avg should be ntype"));
OP_REQUIRES(context, (ntypes == std_tensor.shape().dim_size(0)),
errors::InvalidArgument("number of std should be ntype"));
OP_REQUIRES(
context, (nsamples == type_tensor.shape().dim_size(0)),
deepmd::tf_compat::InvalidArgument("number of samples should match"));
OP_REQUIRES(
context, (nsamples == box_tensor.shape().dim_size(0)),
deepmd::tf_compat::InvalidArgument("number of samples should match"));
OP_REQUIRES(
context, (ntypes == avg_tensor.shape().dim_size(0)),
deepmd::tf_compat::InvalidArgument("number of avg should be ntype"));
OP_REQUIRES(
context, (ntypes == std_tensor.shape().dim_size(0)),
deepmd::tf_compat::InvalidArgument("number of std should be ntype"));

OP_REQUIRES(context, (nall * 3 == coord_tensor.shape().dim_size(1)),
errors::InvalidArgument("number of atoms should match"));
OP_REQUIRES(context, (nall == type_tensor.shape().dim_size(1)),
errors::InvalidArgument("number of atoms should match"));
OP_REQUIRES(context, (9 == box_tensor.shape().dim_size(1)),
errors::InvalidArgument("number of box should be 9"));
OP_REQUIRES(context, (ndescrpt == avg_tensor.shape().dim_size(1)),
errors::InvalidArgument("number of avg should be ndescrpt"));
OP_REQUIRES(context, (ndescrpt == std_tensor.shape().dim_size(1)),
errors::InvalidArgument("number of std should be ndescrpt"));
OP_REQUIRES(
context, (nall * 3 == coord_tensor.shape().dim_size(1)),
deepmd::tf_compat::InvalidArgument("number of atoms should match"));
OP_REQUIRES(
context, (nall == type_tensor.shape().dim_size(1)),
deepmd::tf_compat::InvalidArgument("number of atoms should match"));
OP_REQUIRES(
context, (9 == box_tensor.shape().dim_size(1)),
deepmd::tf_compat::InvalidArgument("number of box should be 9"));
OP_REQUIRES(
context, (ndescrpt == avg_tensor.shape().dim_size(1)),
deepmd::tf_compat::InvalidArgument("number of avg should be ndescrpt"));
OP_REQUIRES(
context, (ndescrpt == std_tensor.shape().dim_size(1)),
deepmd::tf_compat::InvalidArgument("number of std should be ndescrpt"));

int nei_mode = 0;
if (mesh_tensor.shape().dim_size(0) == 16) {
Expand Down Expand Up @@ -201,10 +211,10 @@ class DescrptOp : public OpKernel {
// }
// int ntypes = max_type_v + 1;
OP_REQUIRES(context, (ntypes == int(sel_a.size())),
errors::InvalidArgument(
deepmd::tf_compat::InvalidArgument(
"number of types should match the length of sel array"));
OP_REQUIRES(context, (ntypes == int(sel_r.size())),
errors::InvalidArgument(
deepmd::tf_compat::InvalidArgument(
"number of types should match the length of sel array"));

for (int kk = 0; kk < nsamples; ++kk) {
Expand Down
81 changes: 46 additions & 35 deletions source/op/tf/descrpt_se_a_ef.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,32 +69,33 @@ class DescrptSeAEfOp : public OpKernel {

// set size of the sample
OP_REQUIRES(context, (coord_tensor.shape().dims() == 2),
errors::InvalidArgument("Dim of coord should be 2"));
deepmd::tf_compat::InvalidArgument("Dim of coord should be 2"));
OP_REQUIRES(context, (type_tensor.shape().dims() == 2),
errors::InvalidArgument("Dim of type should be 2"));
OP_REQUIRES(context, (natoms_tensor.shape().dims() == 1),
errors::InvalidArgument("Dim of natoms should be 1"));
deepmd::tf_compat::InvalidArgument("Dim of type should be 2"));
OP_REQUIRES(
context, (natoms_tensor.shape().dims() == 1),
deepmd::tf_compat::InvalidArgument("Dim of natoms should be 1"));
OP_REQUIRES(context, (box_tensor.shape().dims() == 2),
errors::InvalidArgument("Dim of box should be 2"));
deepmd::tf_compat::InvalidArgument("Dim of box should be 2"));
OP_REQUIRES(context, (mesh_tensor.shape().dims() == 1),
errors::InvalidArgument("Dim of mesh should be 1"));
deepmd::tf_compat::InvalidArgument("Dim of mesh should be 1"));
OP_REQUIRES(context, (ef_tensor.shape().dims() == 2),
errors::InvalidArgument("Dim of ef should be 2"));
deepmd::tf_compat::InvalidArgument("Dim of ef should be 2"));
OP_REQUIRES(context, (avg_tensor.shape().dims() == 2),
errors::InvalidArgument("Dim of avg should be 2"));
deepmd::tf_compat::InvalidArgument("Dim of avg should be 2"));
OP_REQUIRES(context, (std_tensor.shape().dims() == 2),
errors::InvalidArgument("Dim of std should be 2"));
deepmd::tf_compat::InvalidArgument("Dim of std should be 2"));
OP_REQUIRES(
context, (fill_nei_a),
errors::InvalidArgument(
deepmd::tf_compat::InvalidArgument(
"Rotational free descriptor only support the case rcut_a < 0"));
OP_REQUIRES(context, (sec_r.back() == 0),
errors::InvalidArgument(
deepmd::tf_compat::InvalidArgument(
"Rotational free descriptor only support all-angular "
"information: sel_r should be all zero."));

OP_REQUIRES(context, (natoms_tensor.shape().dim_size(0) >= 3),
errors::InvalidArgument(
deepmd::tf_compat::InvalidArgument(
"number of atoms should be larger than (or equal to) 3"));
auto natoms = natoms_tensor.flat<int>();
int nloc = natoms(0);
Expand All @@ -103,29 +104,39 @@ class DescrptSeAEfOp : public OpKernel {
int nsamples = coord_tensor.shape().dim_size(0);

// check the sizes
OP_REQUIRES(context, (nsamples == type_tensor.shape().dim_size(0)),
errors::InvalidArgument("number of samples should match"));
OP_REQUIRES(context, (nsamples == box_tensor.shape().dim_size(0)),
errors::InvalidArgument("number of samples should match"));
OP_REQUIRES(context, (nsamples == ef_tensor.shape().dim_size(0)),
errors::InvalidArgument("number of samples should match"));
OP_REQUIRES(context, (ntypes == avg_tensor.shape().dim_size(0)),
errors::InvalidArgument("number of avg should be ntype"));
OP_REQUIRES(context, (ntypes == std_tensor.shape().dim_size(0)),
errors::InvalidArgument("number of std should be ntype"));
OP_REQUIRES(
context, (nsamples == type_tensor.shape().dim_size(0)),
deepmd::tf_compat::InvalidArgument("number of samples should match"));
OP_REQUIRES(
context, (nsamples == box_tensor.shape().dim_size(0)),
deepmd::tf_compat::InvalidArgument("number of samples should match"));
OP_REQUIRES(
context, (nsamples == ef_tensor.shape().dim_size(0)),
deepmd::tf_compat::InvalidArgument("number of samples should match"));
OP_REQUIRES(
context, (ntypes == avg_tensor.shape().dim_size(0)),
deepmd::tf_compat::InvalidArgument("number of avg should be ntype"));
OP_REQUIRES(
context, (ntypes == std_tensor.shape().dim_size(0)),
deepmd::tf_compat::InvalidArgument("number of std should be ntype"));

OP_REQUIRES(context, (nall * 3 == coord_tensor.shape().dim_size(1)),
errors::InvalidArgument("number of atoms should match"));
OP_REQUIRES(context, (nall == type_tensor.shape().dim_size(1)),
errors::InvalidArgument("number of atoms should match"));
OP_REQUIRES(context, (9 == box_tensor.shape().dim_size(1)),
errors::InvalidArgument("number of box should be 9"));
OP_REQUIRES(
context, (nall * 3 == coord_tensor.shape().dim_size(1)),
deepmd::tf_compat::InvalidArgument("number of atoms should match"));
OP_REQUIRES(
context, (nall == type_tensor.shape().dim_size(1)),
deepmd::tf_compat::InvalidArgument("number of atoms should match"));
OP_REQUIRES(
context, (9 == box_tensor.shape().dim_size(1)),
deepmd::tf_compat::InvalidArgument("number of box should be 9"));
OP_REQUIRES(context, (nloc * 3 == ef_tensor.shape().dim_size(1)),
errors::InvalidArgument("number of ef should be 3"));
OP_REQUIRES(context, (ndescrpt == avg_tensor.shape().dim_size(1)),
errors::InvalidArgument("number of avg should be ndescrpt"));
OP_REQUIRES(context, (ndescrpt == std_tensor.shape().dim_size(1)),
errors::InvalidArgument("number of std should be ndescrpt"));
deepmd::tf_compat::InvalidArgument("number of ef should be 3"));
OP_REQUIRES(
context, (ndescrpt == avg_tensor.shape().dim_size(1)),
deepmd::tf_compat::InvalidArgument("number of avg should be ndescrpt"));
OP_REQUIRES(
context, (ndescrpt == std_tensor.shape().dim_size(1)),
deepmd::tf_compat::InvalidArgument("number of std should be ndescrpt"));

int nei_mode = 0;
if (mesh_tensor.shape().dim_size(0) == 16) {
Expand Down Expand Up @@ -208,10 +219,10 @@ class DescrptSeAEfOp : public OpKernel {
// }
// int ntypes = max_type_v + 1;
OP_REQUIRES(context, (ntypes == int(sel_a.size())),
errors::InvalidArgument(
deepmd::tf_compat::InvalidArgument(
"number of types should match the length of sel array"));
OP_REQUIRES(context, (ntypes == int(sel_r.size())),
errors::InvalidArgument(
deepmd::tf_compat::InvalidArgument(
"number of types should match the length of sel array"));

for (int kk = 0; kk < nsamples; ++kk) {
Expand Down
81 changes: 46 additions & 35 deletions source/op/tf/descrpt_se_a_ef_para.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,32 +69,33 @@ class DescrptSeAEfParaOp : public OpKernel {

// set size of the sample
OP_REQUIRES(context, (coord_tensor.shape().dims() == 2),
errors::InvalidArgument("Dim of coord should be 2"));
deepmd::tf_compat::InvalidArgument("Dim of coord should be 2"));
OP_REQUIRES(context, (type_tensor.shape().dims() == 2),
errors::InvalidArgument("Dim of type should be 2"));
OP_REQUIRES(context, (natoms_tensor.shape().dims() == 1),
errors::InvalidArgument("Dim of natoms should be 1"));
deepmd::tf_compat::InvalidArgument("Dim of type should be 2"));
OP_REQUIRES(
context, (natoms_tensor.shape().dims() == 1),
deepmd::tf_compat::InvalidArgument("Dim of natoms should be 1"));
OP_REQUIRES(context, (box_tensor.shape().dims() == 2),
errors::InvalidArgument("Dim of box should be 2"));
deepmd::tf_compat::InvalidArgument("Dim of box should be 2"));
OP_REQUIRES(context, (mesh_tensor.shape().dims() == 1),
errors::InvalidArgument("Dim of mesh should be 1"));
deepmd::tf_compat::InvalidArgument("Dim of mesh should be 1"));
OP_REQUIRES(context, (ef_tensor.shape().dims() == 2),
errors::InvalidArgument("Dim of ef should be 2"));
deepmd::tf_compat::InvalidArgument("Dim of ef should be 2"));
OP_REQUIRES(context, (avg_tensor.shape().dims() == 2),
errors::InvalidArgument("Dim of avg should be 2"));
deepmd::tf_compat::InvalidArgument("Dim of avg should be 2"));
OP_REQUIRES(context, (std_tensor.shape().dims() == 2),
errors::InvalidArgument("Dim of std should be 2"));
deepmd::tf_compat::InvalidArgument("Dim of std should be 2"));
OP_REQUIRES(
context, (fill_nei_a),
errors::InvalidArgument(
deepmd::tf_compat::InvalidArgument(
"Rotational free descriptor only support the case rcut_a < 0"));
OP_REQUIRES(context, (sec_r.back() == 0),
errors::InvalidArgument(
deepmd::tf_compat::InvalidArgument(
"Rotational free descriptor only support all-angular "
"information: sel_r should be all zero."));

OP_REQUIRES(context, (natoms_tensor.shape().dim_size(0) >= 3),
errors::InvalidArgument(
deepmd::tf_compat::InvalidArgument(
"number of atoms should be larger than (or equal to) 3"));
auto natoms = natoms_tensor.flat<int>();
int nloc = natoms(0);
Expand All @@ -103,29 +104,39 @@ class DescrptSeAEfParaOp : public OpKernel {
int nsamples = coord_tensor.shape().dim_size(0);

// check the sizes
OP_REQUIRES(context, (nsamples == type_tensor.shape().dim_size(0)),
errors::InvalidArgument("number of samples should match"));
OP_REQUIRES(context, (nsamples == box_tensor.shape().dim_size(0)),
errors::InvalidArgument("number of samples should match"));
OP_REQUIRES(context, (nsamples == ef_tensor.shape().dim_size(0)),
errors::InvalidArgument("number of samples should match"));
OP_REQUIRES(context, (ntypes == avg_tensor.shape().dim_size(0)),
errors::InvalidArgument("number of avg should be ntype"));
OP_REQUIRES(context, (ntypes == std_tensor.shape().dim_size(0)),
errors::InvalidArgument("number of std should be ntype"));
OP_REQUIRES(
context, (nsamples == type_tensor.shape().dim_size(0)),
deepmd::tf_compat::InvalidArgument("number of samples should match"));
OP_REQUIRES(
context, (nsamples == box_tensor.shape().dim_size(0)),
deepmd::tf_compat::InvalidArgument("number of samples should match"));
OP_REQUIRES(
context, (nsamples == ef_tensor.shape().dim_size(0)),
deepmd::tf_compat::InvalidArgument("number of samples should match"));
OP_REQUIRES(
context, (ntypes == avg_tensor.shape().dim_size(0)),
deepmd::tf_compat::InvalidArgument("number of avg should be ntype"));
OP_REQUIRES(
context, (ntypes == std_tensor.shape().dim_size(0)),
deepmd::tf_compat::InvalidArgument("number of std should be ntype"));

OP_REQUIRES(context, (nall * 3 == coord_tensor.shape().dim_size(1)),
errors::InvalidArgument("number of atoms should match"));
OP_REQUIRES(context, (nall == type_tensor.shape().dim_size(1)),
errors::InvalidArgument("number of atoms should match"));
OP_REQUIRES(context, (9 == box_tensor.shape().dim_size(1)),
errors::InvalidArgument("number of box should be 9"));
OP_REQUIRES(
context, (nall * 3 == coord_tensor.shape().dim_size(1)),
deepmd::tf_compat::InvalidArgument("number of atoms should match"));
OP_REQUIRES(
context, (nall == type_tensor.shape().dim_size(1)),
deepmd::tf_compat::InvalidArgument("number of atoms should match"));
OP_REQUIRES(
context, (9 == box_tensor.shape().dim_size(1)),
deepmd::tf_compat::InvalidArgument("number of box should be 9"));
OP_REQUIRES(context, (nloc * 3 == ef_tensor.shape().dim_size(1)),
errors::InvalidArgument("number of ef should be 3"));
OP_REQUIRES(context, (ndescrpt == avg_tensor.shape().dim_size(1)),
errors::InvalidArgument("number of avg should be ndescrpt"));
OP_REQUIRES(context, (ndescrpt == std_tensor.shape().dim_size(1)),
errors::InvalidArgument("number of std should be ndescrpt"));
deepmd::tf_compat::InvalidArgument("number of ef should be 3"));
OP_REQUIRES(
context, (ndescrpt == avg_tensor.shape().dim_size(1)),
deepmd::tf_compat::InvalidArgument("number of avg should be ndescrpt"));
OP_REQUIRES(
context, (ndescrpt == std_tensor.shape().dim_size(1)),
deepmd::tf_compat::InvalidArgument("number of std should be ndescrpt"));

int nei_mode = 0;
if (mesh_tensor.shape().dim_size(0) == 16) {
Expand Down Expand Up @@ -208,10 +219,10 @@ class DescrptSeAEfParaOp : public OpKernel {
// }
// int ntypes = max_type_v + 1;
OP_REQUIRES(context, (ntypes == int(sel_a.size())),
errors::InvalidArgument(
deepmd::tf_compat::InvalidArgument(
"number of types should match the length of sel array"));
OP_REQUIRES(context, (ntypes == int(sel_r.size())),
errors::InvalidArgument(
deepmd::tf_compat::InvalidArgument(
"number of types should match the length of sel array"));

for (int kk = 0; kk < nsamples; ++kk) {
Expand Down
Loading
Loading