Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions deepmd/tf/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,9 @@ def _init_attr(self) -> None:
self.numb_dos = 0
self.tmap = tmap.decode("utf-8").split()
if self.tensors["modifier_type"] is not None:
self.modifier_type = run_sess(self.sess, [self.tensors["modifier_type"]])[0]
self.modifier_type = run_sess(self.sess, [self.tensors["modifier_type"]])[
0
].decode()
Comment thread
njzjz marked this conversation as resolved.
else:
self.modifier_type = None

Expand Down Expand Up @@ -761,15 +763,17 @@ def eval(
odef.name: oo for oo, odef in zip(output, self.output_def.var_defs.values())
}
# ugly!!
if self.modifier_type is not None and isinstance(self.model_type, DeepPot):
if self.modifier_type is not None and issubclass(self.model_type, DeepPot):
Comment thread
njzjz marked this conversation as resolved.
if atomic:
raise RuntimeError("modifier does not support atomic modification")
me, mf, mv = self.dm.eval(coords, cells, atom_types)
output = list(output) # tuple to list
e, f, v = output[:3]
output_dict["energy_redu"] += me.reshape(e.shape)
output_dict["energy_deri_r"] += mf.reshape(f.shape)
output_dict["energy_deri_c_redu"] += mv.reshape(v.shape)
output_dict["energy_redu"] += me.reshape(output_dict["energy_redu"].shape)
output_dict["energy_derv_r"] += mf.reshape(
output_dict["energy_derv_r"].shape
)
output_dict["energy_derv_c_redu"] += mv.reshape(
output_dict["energy_derv_c_redu"].shape
)
Comment thread
njzjz marked this conversation as resolved.
return output_dict

def _prepare_feed_dict(
Expand Down Expand Up @@ -1348,6 +1352,8 @@ def sort_input(
natoms = atom_type[0].size
idx_map = np.arange(natoms) # pylint: disable=no-explicit-dtype
return coord, atom_type, idx_map
if atom_type.ndim > 1:
atom_type = atom_type[0]
Comment thread
njzjz marked this conversation as resolved.
if sel_atoms is not None:
selection = [False] * np.size(atom_type)
for ii in sel_atoms:
Expand Down
Loading