Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions source/api_cc/src/DataModifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ compute (std::vector<VALUETYPE> & dfcorr_,
if (nloc_real == 0){
dfcorr_.resize(nall * 3);
dvcorr_.resize(9);
fill(dfcorr_.begin(), dfcorr_.end(), 0.0);
fill(dvcorr_.begin(), dvcorr_.end(), 0.0);
fill(dfcorr_.begin(), dfcorr_.end(), (VALUETYPE)0.0);
fill(dvcorr_.begin(), dvcorr_.end(), (VALUETYPE)0.0);
return;
}
// resize to nall_real
Expand Down Expand Up @@ -223,7 +223,7 @@ compute (std::vector<VALUETYPE> & dfcorr_,
assert(dfcorr_1.size() == nall_real * 3);
// resize to all and clear
std::vector<VALUETYPE> dfcorr_2(nall*3);
fill(dfcorr_2.begin(), dfcorr_2.end(), 0.0);
fill(dfcorr_2.begin(), dfcorr_2.end(), (VALUETYPE)0.0);
// back map to original position
for (int ii = 0; ii < nall_real; ++ii){
for (int dd = 0; dd < 3; ++dd){
Expand Down
52 changes: 26 additions & 26 deletions source/api_cc/src/DeepPot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ run_model (ENERGYTYPE & dener,
// no backward map needed
// dforce of size nall * 3
dforce_.resize(nall * 3);
fill(dforce_.begin(), dforce_.end(), 0.0);
fill(dforce_.begin(), dforce_.end(), (VALUETYPE)0.0);
// dvirial of size 9
dvirial.resize(9);
fill(dvirial.begin(), dvirial.end(), 0.0);
fill(dvirial.begin(), dvirial.end(), (VALUETYPE)0.0);
return;
}

Expand All @@ -62,17 +62,17 @@ run_model (ENERGYTYPE & dener,
dforce[ii] = of(ii);
}
// set dvirial to zero, prevent input vector is not zero (#1123)
std::fill(dvirial.begin(), dvirial.end(), 0.);
std::fill(dvirial.begin(), dvirial.end(), (VALUETYPE)0.);
for (int ii = 0; ii < nall; ++ii) {
dvirial[0] += 1.0 * oav(9*ii+0);
dvirial[1] += 1.0 * oav(9*ii+1);
dvirial[2] += 1.0 * oav(9*ii+2);
dvirial[3] += 1.0 * oav(9*ii+3);
dvirial[4] += 1.0 * oav(9*ii+4);
dvirial[5] += 1.0 * oav(9*ii+5);
dvirial[6] += 1.0 * oav(9*ii+6);
dvirial[7] += 1.0 * oav(9*ii+7);
dvirial[8] += 1.0 * oav(9*ii+8);
dvirial[0] += (VALUETYPE)1.0 * oav(9*ii+0);
dvirial[1] += (VALUETYPE)1.0 * oav(9*ii+1);
dvirial[2] += (VALUETYPE)1.0 * oav(9*ii+2);
dvirial[3] += (VALUETYPE)1.0 * oav(9*ii+3);
dvirial[4] += (VALUETYPE)1.0 * oav(9*ii+4);
dvirial[5] += (VALUETYPE)1.0 * oav(9*ii+5);
dvirial[6] += (VALUETYPE)1.0 * oav(9*ii+6);
dvirial[7] += (VALUETYPE)1.0 * oav(9*ii+7);
dvirial[8] += (VALUETYPE)1.0 * oav(9*ii+8);
}
dforce_ = dforce;
atommap.backward (dforce_.begin(), dforce.begin(), 3);
Expand All @@ -95,16 +95,16 @@ static void run_model (ENERGYTYPE & dener,
// no backward map needed
// dforce of size nall * 3
dforce_.resize(nall * 3);
fill(dforce_.begin(), dforce_.end(), 0.0);
fill(dforce_.begin(), dforce_.end(), (VALUETYPE)0.0);
// dvirial of size 9
dvirial.resize(9);
fill(dvirial.begin(), dvirial.end(), 0.0);
fill(dvirial.begin(), dvirial.end(), (VALUETYPE)0.0);
// datom_energy_ of size nall
datom_energy_.resize(nall);
fill(datom_energy_.begin(), datom_energy_.end(), 0.0);
fill(datom_energy_.begin(), datom_energy_.end(), (VALUETYPE)0.0);
// datom_virial_ of size nall * 9
datom_virial_.resize(nall * 9);
fill(datom_virial_.begin(), datom_virial_.end(), 0.0);
fill(datom_virial_.begin(), datom_virial_.end(), (VALUETYPE)0.0);
return;
}
std::vector<Tensor> output_tensors;
Expand Down Expand Up @@ -139,17 +139,17 @@ static void run_model (ENERGYTYPE & dener,
datom_virial[ii] = oav(ii);
}
// set dvirial to zero, prevent input vector is not zero (#1123)
std::fill(dvirial.begin(), dvirial.end(), 0.);
std::fill(dvirial.begin(), dvirial.end(), (VALUETYPE)0.);
for (int ii = 0; ii < nall; ++ii) {
dvirial[0] += 1.0 * datom_virial[9*ii+0];
dvirial[1] += 1.0 * datom_virial[9*ii+1];
dvirial[2] += 1.0 * datom_virial[9*ii+2];
dvirial[3] += 1.0 * datom_virial[9*ii+3];
dvirial[4] += 1.0 * datom_virial[9*ii+4];
dvirial[5] += 1.0 * datom_virial[9*ii+5];
dvirial[6] += 1.0 * datom_virial[9*ii+6];
dvirial[7] += 1.0 * datom_virial[9*ii+7];
dvirial[8] += 1.0 * datom_virial[9*ii+8];
dvirial[0] += (VALUETYPE)1.0 * datom_virial[9*ii+0];
dvirial[1] += (VALUETYPE)1.0 * datom_virial[9*ii+1];
dvirial[2] += (VALUETYPE)1.0 * datom_virial[9*ii+2];
dvirial[3] += (VALUETYPE)1.0 * datom_virial[9*ii+3];
dvirial[4] += (VALUETYPE)1.0 * datom_virial[9*ii+4];
dvirial[5] += (VALUETYPE)1.0 * datom_virial[9*ii+5];
dvirial[6] += (VALUETYPE)1.0 * datom_virial[9*ii+6];
dvirial[7] += (VALUETYPE)1.0 * datom_virial[9*ii+7];
dvirial[8] += (VALUETYPE)1.0 * datom_virial[9*ii+8];
}
dforce_ = dforce;
datom_energy_ = datom_energy;
Expand Down
2 changes: 1 addition & 1 deletion source/lib/include/utilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ template <>
inline float
invsqrt<float> (const float x)
{
return 1./sqrtf (x);
return 1.f/sqrtf (x);
}

}
4 changes: 2 additions & 2 deletions source/lib/src/coord.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ normalize_coord_cpu(
FPTYPE ri[3];
convert_to_inter_cpu(ri, region, coord+3*ii);
for(int dd = 0; dd < 3; ++dd){
ri[dd] = fmod(ri[dd], 1.);
if (ri[dd] < 0.) ri[dd] += 1.;
ri[dd] = fmod(ri[dd], (FPTYPE)1.);
if (ri[dd] < (FPTYPE)0.) ri[dd] += (FPTYPE)1.;
}
convert_to_phys_cpu(coord+3*ii, region, ri);
}
Expand Down
12 changes: 8 additions & 4 deletions source/lib/src/cuda/coord.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ __device__ inline int compute_pbc_shift(
return shift;
}

__device__ inline double _fmod(double x, double y) {return fmod(x, y);}
__device__ inline float _fmod(float x, float y) {return fmodf(x, y);}


template<typename FPTYPE>
__global__ void normalize_one(
FPTYPE *out_c,
Expand All @@ -64,8 +68,8 @@ __global__ void normalize_one(
FPTYPE inter[3];
phys2Inter(inter,out_c+idy*3,rec_boxt);
for (int dd = 0; dd < 3; ++dd) {
inter[dd]=(FPTYPE)fmod((double)inter[dd], 1.);
if (inter[dd] < 0.) inter[dd] += 1.;
inter[dd]=_fmod(inter[dd], (FPTYPE)1.);
if (inter[dd] < (FPTYPE)0.) inter[dd] += (FPTYPE)1.;
}
inter2Phys(out_c+idy*3,inter,boxt);
}
Expand Down Expand Up @@ -93,7 +97,7 @@ __global__ void _fill_idx_cellmap(
ext_ncell[dd] = ext_end[dd] - ext_stt[dd];
global_grid[dd] = nat_end[dd] - nat_stt[dd];
idx_orig_shift[dd] = nat_stt[dd] - ext_stt[dd];
cell_size[dd] = 1./global_grid[dd];
cell_size[dd] = (FPTYPE)1./global_grid[dd];
nat_orig[dd] = nat_stt[dd] * cell_size[dd];
}
if (idy<nloc)
Expand All @@ -104,7 +108,7 @@ __global__ void _fill_idx_cellmap(
phys2Inter(inter,in_c+idy*3,rec_boxt);
for (int dd = 0; dd < 3; ++dd){
idx_noshift[dd] = (inter[dd] - nat_orig[dd]) / cell_size[dd];
if (inter[dd] - nat_orig[dd] < 0.) idx_noshift[dd] --;
if (inter[dd] - nat_orig[dd] < (FPTYPE)0.) idx_noshift[dd] --;
if (idx_noshift[dd] < nat_stt[dd])
{
idx_noshift[dd] = nat_stt[dd];
Expand Down
15 changes: 9 additions & 6 deletions source/lib/src/cuda/gelu.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#include "gelu.h"
#include "device.h"

__device__ inline double _tanh(double x) {return tanh(x);}
__device__ inline float _tanh(float x) {return tanhf(x);}

template <typename FPTYPE>
__global__ void gelu(
FPTYPE * out,
Expand All @@ -11,7 +14,7 @@ __global__ void gelu(
if (idx >= size) {
return;
}
out[idx] = xx[idx] * 0.5 * (1.0 + tanh(SQRT_2_PI * (xx[idx] + 0.044715 * xx[idx] * xx[idx] *xx[idx])));
out[idx] = xx[idx] * (FPTYPE)0.5 * ((FPTYPE)1.0 + _tanh((FPTYPE)SQRT_2_PI * (xx[idx] + (FPTYPE)0.044715 * xx[idx] * xx[idx] *xx[idx])));
}

template <typename FPTYPE>
Expand All @@ -26,8 +29,8 @@ __global__ void gelu_grad(
return;
}
// out[idx] = xx[idx] * 0.5 * (1.0 + tanh(SQRT_2_PI * (xx[idx] + 0.044715 * xx[idx] * xx[idx] *xx[idx])));
const FPTYPE var = tanh(SQRT_2_PI * (xx[idx] + 0.044715 * xx[idx] * xx[idx] *xx[idx]));
out[idx] = dy[idx] * (0.5 * SQRT_2_PI * xx[idx] * (1 - var * var) * (0.134145 * xx[idx] * xx[idx] + 1) + 0.5 * var + 0.5);
const FPTYPE var = _tanh((FPTYPE)SQRT_2_PI * (xx[idx] + (FPTYPE)0.044715 * xx[idx] * xx[idx] *xx[idx]));
out[idx] = dy[idx] * ((FPTYPE)0.5 * SQRT_2_PI * xx[idx] * ((FPTYPE)1. - var * var) * ((FPTYPE)0.134145 * xx[idx] * xx[idx] + 1) + (FPTYPE)0.5 * var + (FPTYPE)0.5);
}

template <typename FPTYPE>
Expand All @@ -43,9 +46,9 @@ __global__ void gelu_grad_grad(
return;
}
// out[idx] = xx[idx] * 0.5 * (1.0 + tanh(SQRT_2_PI * (xx[idx] + 0.044715 * xx[idx] * xx[idx] *xx[idx])));
const FPTYPE var1 = tanh(SQRT_2_PI * (xx[idx] + 0.044715 * xx[idx] * xx[idx] *xx[idx]));
const FPTYPE var2 = SQRT_2_PI * (1 - var1 * var1) * (0.134145 * xx[idx] * xx[idx] + 1);
out[idx] = dy[idx] * dy_2[idx] * (0.134145 * SQRT_2_PI * xx[idx] * xx[idx] * (1 - var1 * var1) - SQRT_2_PI * xx[idx] * var2 * (0.134145 * xx[idx] * xx[idx] + 1) * var1 + var2);
const FPTYPE var1 = _tanh((FPTYPE)SQRT_2_PI * (xx[idx] + (FPTYPE)0.044715 * xx[idx] * xx[idx] *xx[idx]));
const FPTYPE var2 = (FPTYPE)SQRT_2_PI * ((FPTYPE)1. - var1 * var1) * ((FPTYPE)0.134145 * xx[idx] * xx[idx] + (FPTYPE)1.);
out[idx] = dy[idx] * dy_2[idx] * ((FPTYPE)0.134145 * (FPTYPE)SQRT_2_PI * xx[idx] * xx[idx] * ((FPTYPE)1. - var1 * var1) - (FPTYPE)SQRT_2_PI * xx[idx] * var2 * ((FPTYPE)0.134145 * xx[idx] * xx[idx] + (FPTYPE)1.) * var1 + var2);
}

namespace deepmd {
Expand Down
47 changes: 25 additions & 22 deletions source/lib/src/cuda/prod_env_mat.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
#include <cub/block/block_store.cuh>
#include <cub/block/block_radix_sort.cuh>

__device__ inline double _sqrt(double x) {return sqrt(x);}
__device__ inline float _sqrt(float x) {return sqrtf(x);}

// common part of prod_env_mat
template <
typename Key,
Expand Down Expand Up @@ -57,18 +60,18 @@ __device__ inline void spline5_switch(
const float & rmax)
{
if (xx < rmin) {
dd = 0;
vv = 1;
dd = (FPTYPE)0.;
vv = (FPTYPE)1.;
}
else if (xx < rmax) {
FPTYPE uu = (xx - rmin) / (rmax - rmin) ;
FPTYPE du = 1. / (rmax - rmin) ;
vv = uu*uu*uu * (-6 * uu*uu + 15 * uu - 10) + 1;
dd = ( 3 * uu*uu * (-6 * uu*uu + 15 * uu - 10) + uu*uu*uu * (-12 * uu + 15) ) * du;
FPTYPE du = (FPTYPE)1. / (rmax - rmin) ;
vv = uu*uu*uu * ((FPTYPE)-6. * uu*uu + (FPTYPE)15. * uu - (FPTYPE)10.) + (FPTYPE)1.;
dd = ( (FPTYPE)3. * uu*uu * ((FPTYPE)-6. * uu*uu + (FPTYPE)15. * uu - (FPTYPE)10.) + uu*uu*uu * ((FPTYPE)-12. * uu + (FPTYPE)15.) ) * du;
}
else {
dd = 0;
vv = 0;
dd = (FPTYPE)0.;
vv = (FPTYPE)0.;
}
}

Expand All @@ -82,7 +85,7 @@ __device__ inline uint_64 encoding_nbor_info(
// the type of nbor atom must be smaller than 128
// the distance of center atom between nbor atom must be smaller than 128
// the index of nbor atom(including ghost region) must be smaller than 16777216(1 << 24)
if(type >= 128 || dist >= 128.0 || index >= (1 << 24)) {
if(type >= 128 || dist >= (FPTYPE)128.0 || index >= (1 << 24)) {
asm("trap;");
}
return ((uint_64)type << 57) + (uint_64)((double)dist * ((uint_64)1 << 50)) / (1 << 24) * (1 << 24) + index;
Expand Down Expand Up @@ -138,7 +141,7 @@ __global__ void format_nlist_fill_a(
for (int dd = 0; dd < 3; dd++) {
diff[dd] = coord[j_idx * 3 + dd] - coord[idx * 3 + dd];
}
FPTYPE rr = sqrt(dev_dot(diff, diff));
FPTYPE rr = _sqrt(dev_dot(diff, diff));
if (rr <= rcut) {
key_in[idy] = encoding_nbor_info(type[j_idx], rr, j_idx);
}
Expand Down Expand Up @@ -345,32 +348,32 @@ __global__ void compute_env_mat_a(
}
// const FPTYPE * rr = &row_rij[ii * 3];
FPTYPE nr2 = dev_dot(rr, rr);
FPTYPE inr = 1./sqrt(nr2);
FPTYPE inr = (FPTYPE)1./_sqrt(nr2);
FPTYPE nr = nr2 * inr;
FPTYPE inr2 = inr * inr;
FPTYPE inr4 = inr2 * inr2;
FPTYPE inr3 = inr4 * nr;
FPTYPE sw, dsw;
spline5_switch(sw, dsw, nr, rmin, rmax);
dd[0] = (1./nr) ;//* sw;
dd[0] = ((FPTYPE)1./nr) ;//* sw;
dd[1] = (rr[0] / nr2) ;//* sw;
dd[2] = (rr[1] / nr2) ;//* sw;
dd[3] = (rr[2] / nr2) ;//* sw;
vv[0] = (rr[0] * inr3 * sw - dd[0] * dsw * rr[0] * inr); // avg[type[(idx_deriv + 0) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 0) % (ndescrpt * 3)) / 3];
vv[1] = (rr[1] * inr3 * sw - dd[0] * dsw * rr[1] * inr); // avg[type[(idx_deriv + 1) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 1) % (ndescrpt * 3)) / 3];
vv[2] = (rr[2] * inr3 * sw - dd[0] * dsw * rr[2] * inr); // avg[type[(idx_deriv + 2) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 2) % (ndescrpt * 3)) / 3];
// ****deriv of component x/r2
vv[3] = ((2. * rr[0] * rr[0] * inr4 - inr2) * sw - dd[1] * dsw * rr[0] * inr); // avg[type[(idx_deriv + 3) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 3) % (ndescrpt * 3)) / 3];
vv[4] = ((2. * rr[0] * rr[1] * inr4 ) * sw - dd[1] * dsw * rr[1] * inr); // avg[type[(idx_deriv + 4) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 4) % (ndescrpt * 3)) / 3];
vv[5] = ((2. * rr[0] * rr[2] * inr4 ) * sw - dd[1] * dsw * rr[2] * inr); // avg[type[(idx_deriv + 5) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 5) % (ndescrpt * 3)) / 3];
vv[3] = (((FPTYPE)2. * rr[0] * rr[0] * inr4 - inr2) * sw - dd[1] * dsw * rr[0] * inr); // avg[type[(idx_deriv + 3) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 3) % (ndescrpt * 3)) / 3];
vv[4] = (((FPTYPE)2. * rr[0] * rr[1] * inr4 ) * sw - dd[1] * dsw * rr[1] * inr); // avg[type[(idx_deriv + 4) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 4) % (ndescrpt * 3)) / 3];
vv[5] = (((FPTYPE)2. * rr[0] * rr[2] * inr4 ) * sw - dd[1] * dsw * rr[2] * inr); // avg[type[(idx_deriv + 5) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 5) % (ndescrpt * 3)) / 3];
// ***deriv of component y/r2
vv[6] = ((2. * rr[1] * rr[0] * inr4 ) * sw - dd[2] * dsw * rr[0] * inr); // avg[type[(idx_deriv + 6) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 6) % (ndescrpt * 3)) / 3];
vv[7] = ((2. * rr[1] * rr[1] * inr4 - inr2) * sw - dd[2] * dsw * rr[1] * inr); // avg[type[(idx_deriv + 7) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 7) % (ndescrpt * 3)) / 3];
vv[8] = ((2. * rr[1] * rr[2] * inr4 ) * sw - dd[2] * dsw * rr[2] * inr); // avg[type[(idx_deriv + 8) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 8) % (ndescrpt * 3)) / 3];
vv[6] = (((FPTYPE)2. * rr[1] * rr[0] * inr4 ) * sw - dd[2] * dsw * rr[0] * inr); // avg[type[(idx_deriv + 6) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 6) % (ndescrpt * 3)) / 3];
vv[7] = (((FPTYPE)2. * rr[1] * rr[1] * inr4 - inr2) * sw - dd[2] * dsw * rr[1] * inr); // avg[type[(idx_deriv + 7) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 7) % (ndescrpt * 3)) / 3];
vv[8] = (((FPTYPE)2. * rr[1] * rr[2] * inr4 ) * sw - dd[2] * dsw * rr[2] * inr); // avg[type[(idx_deriv + 8) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 8) % (ndescrpt * 3)) / 3];
// ***deriv of component z/r2
vv[9] = ((2. * rr[2] * rr[0] * inr4 ) * sw - dd[3] * dsw * rr[0] * inr); // avg[type[(idx_deriv + 9) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 9) % (ndescrpt * 3)) / 3];
vv[10]= ((2. * rr[2] * rr[1] * inr4 ) * sw - dd[3] * dsw * rr[1] * inr); // avg[type[(idx_deriv + 10) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 10) % (ndescrpt * 3)) / 3];
vv[11]= ((2. * rr[2] * rr[2] * inr4 - inr2) * sw - dd[3] * dsw * rr[2] * inr); // avg[type[(idx_deriv + 11) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 11) % (ndescrpt * 3)) / 3];
vv[9] = (((FPTYPE)2. * rr[2] * rr[0] * inr4 ) * sw - dd[3] * dsw * rr[0] * inr); // avg[type[(idx_deriv + 9) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 9) % (ndescrpt * 3)) / 3];
vv[10]= (((FPTYPE)2. * rr[2] * rr[1] * inr4 ) * sw - dd[3] * dsw * rr[1] * inr); // avg[type[(idx_deriv + 10) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 10) % (ndescrpt * 3)) / 3];
vv[11]= (((FPTYPE)2. * rr[2] * rr[2] * inr4 - inr2) * sw - dd[3] * dsw * rr[2] * inr); // avg[type[(idx_deriv + 11) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 11) % (ndescrpt * 3)) / 3];
// 4 value components
dd[0] *= sw; // * em[idx * ndescrpt + idx_value + 0]);// - avg[type[idx] * ndescrpt + idx_value + 0]) / std[type[idx] * ndescrpt + idx_value + 0];
dd[1] *= sw; // * em[idx * ndescrpt + idx_value + 1]);// - avg[type[idx] * ndescrpt + idx_value + 1]) / std[type[idx] * ndescrpt + idx_value + 1];
Expand Down Expand Up @@ -431,14 +434,14 @@ __global__ void compute_env_mat_r(
}
// const FPTYPE * rr = &row_rij[ii * 3];
FPTYPE nr2 = dev_dot(rr, rr);
FPTYPE inr = 1./sqrt(nr2);
FPTYPE inr = (FPTYPE)1./_sqrt(nr2);
FPTYPE nr = nr2 * inr;
FPTYPE inr2 = inr * inr;
FPTYPE inr4 = inr2 * inr2;
FPTYPE inr3 = inr4 * nr;
FPTYPE sw, dsw;
spline5_switch(sw, dsw, nr, rmin, rmax);
dd = (1./nr) ;//* sw;
dd = ((FPTYPE)1./nr) ;//* sw;
vv[0] = (rr[0] * inr3 * sw - dd * dsw * rr[0] * inr); // avg[type[(idx_deriv + 0) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 0) % (ndescrpt * 3)) / 3];
vv[1] = (rr[1] * inr3 * sw - dd * dsw * rr[1] * inr); // avg[type[(idx_deriv + 1) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 1) % (ndescrpt * 3)) / 3];
vv[2] = (rr[2] * inr3 * sw - dd * dsw * rr[2] * inr); // avg[type[(idx_deriv + 2) / (ndescrpt * 3)] * ndescrpt + ((idx_deriv + 2) % (ndescrpt * 3)) / 3];
Expand Down
4 changes: 2 additions & 2 deletions source/lib/src/cuda/prod_virial.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ __global__ void atom_virial_reduction(
unsigned int bid = blockIdx.x;
unsigned int tid = threadIdx.x;
__shared__ FPTYPE data[THREADS_PER_BLOCK];
data[tid] = 0.f;
data[tid] = (FPTYPE)0.;
for (int ii = tid; ii < nall; ii += THREADS_PER_BLOCK) {
data[tid] += atom_virial[ii * 9 + bid];
}
Expand Down Expand Up @@ -58,7 +58,7 @@ __global__ void virial_deriv_wrt_neighbors_a(
// atomicAdd(
// virial + idz,
// net_deriv[idx * ndescrpt + idy * 4 + idw] * rij[idx * nnei * 3 + idy * 3 + idz / 3] * in_deriv[idx * ndescrpt * 3 + (idy * 4 + idw) * 3 + idz % 3]);
FPTYPE virial_tmp = 0.f;
FPTYPE virial_tmp = (FPTYPE)0.;
for (int idw = 0; idw < 4; ++idw) {
virial_tmp += net_deriv[idx * ndescrpt + idy * 4 + idw] * rij[idx * nnei * 3 + idy * 3 + idz % 3] * in_deriv[idx * ndescrpt * 3 + (idy * 4 + idw) * 3 + idz / 3];
}
Expand Down
6 changes: 3 additions & 3 deletions source/lib/src/cuda/prod_virial_grad.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ __device__ inline FPTYPE dev_dot9(
const FPTYPE * arr1,
const FPTYPE * arr2)
{
FPTYPE result = 0.0;
FPTYPE result = (FPTYPE)0.0;
for(int ii=0; ii<9; ii++){
result += arr1[ii] * arr2[ii];
}
Expand Down Expand Up @@ -47,7 +47,7 @@ __global__ void virial_grad_wrt_neighbors_a(
tmp[dd0 * 3 + dd1] = rij[idx * nnei * 3 + idy * 3 + dd1] * env_deriv[idx * ndescrpt * 3 + idy * 4 * 3 + idw * 3 + dd0];
}
}
grad_net[idx * ndescrpt + idy * 4 + idw] -= -1.0 * dev_dot9(grad_one, tmp);
grad_net[idx * ndescrpt + idy * 4 + idw] -= (FPTYPE)-1.0 * dev_dot9(grad_one, tmp);
}

template<typename FPTYPE>
Expand Down Expand Up @@ -83,7 +83,7 @@ __global__ void virial_grad_wrt_neighbors_r(
tmp[dd0 * 3 + dd1] = rij[idx * nnei * 3 + idy * 3 + dd1] * env_deriv[idx * ndescrpt * 3 + idy * 3 + dd0];
}
}
grad_net[idx * ndescrpt + idy] -= -1.0 * dev_dot9(grad_one, tmp);
grad_net[idx * ndescrpt + idy] -= (FPTYPE)-1.0 * dev_dot9(grad_one, tmp);
}

namespace deepmd {
Expand Down
Loading