diff --git a/source/tests/pt_expt/model/test_export_with_comm.py b/source/tests/pt_expt/model/test_export_with_comm.py index ec305f2ed0..f34b20932f 100644 --- a/source/tests/pt_expt/model/test_export_with_comm.py +++ b/source/tests/pt_expt/model/test_export_with_comm.py @@ -155,12 +155,14 @@ def test_pt2_dual_artifact_for_gnn(tmp_path) -> None: # 3. Run both artifacts with nframes=1 (matches what the with-comm # artifact requires; LAMMPS always passes one frame anyway). sample = _make_sample_inputs(model, nframes=1, has_spin=False) - ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = sample + ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam, charge_spin = sample nloc = nlist_t.shape[1] nall = ext_atype.shape[1] nghost = nall - nloc - out_regular = regular(ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam) + out_regular = regular( + ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam, charge_spin + ) # 4. Build runtime comm tensors mirroring the mapping (single-rank # self-send: ghost slot ii receives node[mapping[ii]], identical to @@ -181,6 +183,7 @@ def test_pt2_dual_artifact_for_gnn(tmp_path) -> None: mapping_t, fparam, aparam, + charge_spin, *comm_inputs, )