Skip to content
12 changes: 11 additions & 1 deletion deepmd/tf/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,17 @@ def _pass_filter(
tf.shape(inputs_i)[0],
self.nei_type_vec, # extra input for atten
)
inputs_i *= mask
if self.smooth:
inputs_i = tf.where(
tf.cast(mask, tf.bool),
inputs_i,
# (nframes * nloc, 1) -> (nframes * nloc, ndescrpt)
tf.tile(
tf.reshape(self.avg_looked_up, [-1, 1]), [1, self.ndescrpt]
),
)
else:
inputs_i *= mask
if nvnmd_cfg.enable and nvnmd_cfg.quantize_descriptor:
inputs_i = descrpt2r4(inputs_i, atype)
layer, qmat = self._filter(
Expand Down