Skip to content

Commit ad1ab0d

Browse files
author
Lukas Petersen
committed
feat: add distance calculation and r_max handling for improved edge and angle processing
1 parent 8c8ea9c commit ad1ab0d

File tree

5 files changed

+178
-100
lines changed

5 files changed

+178
-100
lines changed

src/gromacs/mdlib/qm_pytorch.cpp

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -684,7 +684,7 @@ real call_pytorch(QMMM_rec* qr,
684684
if (n_active_models > 1) {
685685
// Save the means
686686
fprintf(f_std, "%d\n", n);
687-
fprintf(f_std, "means step=%d energy_mean=%8.4f\n", step, energy_mean);
687+
fprintf(f_std, "Means step %d: %8.4f\n", step, energy_mean);
688688

689689
for (int i=0; i<n; i++) {
690690
fprintf(f_std, "%-2s %8.4f %8.4f %8.4f\n",
@@ -694,7 +694,7 @@ real call_pytorch(QMMM_rec* qr,
694694

695695
// Save the stds
696696
fprintf(f_std, "%d\n", n);
697-
fprintf(f_std, "stds step=%d energy_std=%8.4f\n", step, energy_std);
697+
fprintf(f_std, "Stds step %d: %8.4f\n", step, energy_std);
698698
for (int i=0; i<n; i++) {
699699
fprintf(f_std, "%-2s %8.4f %8.4f %8.4f\n",
700700
periodic_system[qm->atomicnumberQM_get(i)],
@@ -775,15 +775,14 @@ void prepare_base_mace_inputs(QMMM_QMrec* qm,
775775
continue;
776776
}
777777

778-
double r_squared = 0.0;
779-
for (int k=0; k<3; k++) {
780-
double delta = (qm->xQM_get(i,k) - qm->xQM_get(j,k)) * geometry_conversion;
781-
r_squared += delta*delta;
782-
}
783-
if (r_squared < qm->r_max_squared) {
784-
n_edges += 1;
785-
n_edges_vec[i] += 1;
778+
float r = calculate_distance(qm, i, j) * geometry_conversion;
779+
780+
if (r > qm->r_max) {
781+
continue;
786782
}
783+
784+
n_edges++;
785+
n_edges_vec[i] += 1;
787786
}
788787
}
789788

@@ -806,12 +805,8 @@ void prepare_base_mace_inputs(QMMM_QMrec* qm,
806805
continue;
807806
}
808807

809-
double r_squared = 0.0;
810-
for (int k=0; k<3; k++) {
811-
double delta = (qm->xQM_get(i,k) - qm->xQM_get(j,k)) * geometry_conversion;
812-
r_squared += delta*delta;
813-
}
814-
if (r_squared > qm->r_max_squared) {
808+
float r = calculate_distance(qm, i, j) * geometry_conversion;
809+
if (r > qm->r_max) {
815810
continue;
816811
}
817812

@@ -1124,6 +1119,12 @@ void write_base_mace_inputs_outputs(QMMM_QMrec* qm,
11241119
}
11251120
// fclose(f_output);
11261121

1122+
// Only proceed to write extended xyz if the model is mace
1123+
if (std::strcmp(qm->models[0]->modelArchitecture, "mace") != 0)
1124+
{
1125+
return;
1126+
}
1127+
11271128
// Write all inputs and outputs in the extended xyz format
11281129
char periodic_system[37][3]={"XX",
11291130
"H", "He",
@@ -1523,6 +1524,22 @@ c10::Dict<std::string, torch::Tensor> convertDict(QMMM_QMrec* qm, const c10::imp
15231524

15241525
return outputDict;
15251526
} // end of convertDict
1527+
1528+
/****************************************
1529+
****** AUXILIARY ROUTINES **********
1530+
****************************************/
1531+
1532+
// adopted from src/gmxlib/pbc.c and qmmm-calculation.cpp
1533+
real calculate_distance(QMMM_QMrec* qm, const int x1_idx, const int x2_idx)
1534+
{
1535+
rvec bond;
1536+
for(int i=0; i<DIM; i++) {
1537+
bond[i] = qm->xQM_get(x1_idx,i) - qm->xQM_get(x2_idx,i);
1538+
}
1539+
1540+
return norm(bond);
1541+
}
1542+
15261543
/* end of NN sub routines */
15271544
#endif
15281545

src/gromacs/mdlib/qm_pytorch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ void write_amp_inputs_outputs(QMMM_QMrec* qm,
8080
int step);
8181

8282
c10::Dict<std::string, torch::Tensor> convertDict(QMMM_QMrec* qm, const c10::impl::GenericDict& inputDict);
83+
float calculate_distance(QMMM_QMrec* qm, const int x1_idx, const int x2_idx);
8384
#endif
8485

8586
#endif

0 commit comments

Comments
 (0)