diff --git a/source/tests/pd/common.py b/source/tests/pd/common.py index d73544c5f1..ec36fd0eb9 100644 --- a/source/tests/pd/common.py +++ b/source/tests/pd/common.py @@ -79,7 +79,12 @@ def eval_model( if spins is not None: assert isinstance(spins, paddle.Tensor), err_msg assert isinstance(atom_types, paddle.Tensor) or isinstance(atom_types, list) - atom_types = paddle.to_tensor(atom_types, dtype=paddle.int32, place=DEVICE) + if isinstance(atom_types, paddle.Tensor): + atom_types = ( + atom_types.clone().detach().to(dtype=paddle.int32, device=DEVICE) + ) + else: + atom_types = paddle.to_tensor(atom_types, dtype=paddle.int32, place=DEVICE) elif isinstance(coords, np.ndarray): if cells is not None: assert isinstance(cells, np.ndarray), err_msg @@ -101,28 +106,57 @@ def eval_model( else: natoms = len(atom_types[0]) - coord_input = paddle.to_tensor( - coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE - ) - spin_input = None - if spins is not None: - spin_input = paddle.to_tensor( - spins.reshape([-1, natoms, 3]), + if isinstance(coords, paddle.Tensor): + coord_input = ( + coords.reshape([-1, natoms, 3]) + .clone() + .detach() + .to(dtype=GLOBAL_PD_FLOAT_PRECISION, device=DEVICE) + ) + else: + coord_input = paddle.to_tensor( + coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE, ) + spin_input = None + if spins is not None: + if isinstance(spins, paddle.Tensor): + spin_input = ( + spins.reshape([-1, natoms, 3]) + .clone() + .detach() + .to(dtype=GLOBAL_PD_FLOAT_PRECISION, device=DEVICE) + ) + else: + spin_input = paddle.to_tensor( + spins.reshape([-1, natoms, 3]), + dtype=GLOBAL_PD_FLOAT_PRECISION, + place=DEVICE, + ) has_spin = getattr(model, "has_spin", False) if callable(has_spin): has_spin = has_spin() - type_input = paddle.to_tensor(atom_types, dtype=paddle.int64, place=DEVICE) + if isinstance(atom_types, paddle.Tensor): + type_input = atom_types.clone().detach().to(dtype=paddle.int64, device=DEVICE) + else: + type_input = paddle.to_tensor(atom_types, dtype=paddle.int64, place=DEVICE) box_input = None if cells is None: pbc = False else: pbc = True - box_input = paddle.to_tensor( - cells.reshape([-1, 3, 3]), dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE - ) + if isinstance(cells, paddle.Tensor): + box_input = ( + cells.reshape([-1, 3, 3]) + .clone() + .detach() + .to(dtype=GLOBAL_PD_FLOAT_PRECISION, device=DEVICE) + ) + else: + box_input = paddle.to_tensor( + cells.reshape([-1, 3, 3]), dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE + ) num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size) for ii in range(num_iter): diff --git a/source/tests/pt/common.py b/source/tests/pt/common.py index 8709c8b4f9..2dbfdb84ff 100644 --- a/source/tests/pt/common.py +++ b/source/tests/pt/common.py @@ -79,7 +79,12 @@ def eval_model( if spins is not None: assert isinstance(spins, torch.Tensor), err_msg assert isinstance(atom_types, torch.Tensor) or isinstance(atom_types, list) - atom_types = torch.tensor(atom_types, dtype=torch.int32, device=DEVICE) + if isinstance(atom_types, torch.Tensor): + atom_types = ( + atom_types.clone().detach().to(dtype=torch.int32, device=DEVICE) + ) + else: + atom_types = torch.tensor(atom_types, dtype=torch.int32, device=DEVICE) elif isinstance(coords, np.ndarray): if cells is not None: assert isinstance(cells, np.ndarray), err_msg @@ -101,28 +106,59 @@ def eval_model( else: natoms = len(atom_types[0]) - coord_input = torch.tensor( - coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE - ) - spin_input = None - if spins is not None: - spin_input = torch.tensor( - spins.reshape([-1, natoms, 3]), + if isinstance(coords, torch.Tensor): + coord_input = ( + coords.reshape([-1, natoms, 3]) + .clone() + .detach() + .to(dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE) + ) + else: + coord_input = torch.tensor( + coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE, ) + spin_input = None + if spins is not None: + if isinstance(spins, torch.Tensor): + spin_input = ( + spins.reshape([-1, natoms, 3]) + .clone() + .detach() + .to(dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE) + ) + else: + spin_input = torch.tensor( + spins.reshape([-1, natoms, 3]), + dtype=GLOBAL_PT_FLOAT_PRECISION, + device=DEVICE, + ) has_spin = getattr(model, "has_spin", False) if callable(has_spin): has_spin = has_spin() - type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE) + if isinstance(atom_types, torch.Tensor): + type_input = atom_types.clone().detach().to(dtype=torch.long, device=DEVICE) + else: + type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE) box_input = None if cells is None: pbc = False else: pbc = True - box_input = torch.tensor( - cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE - ) + if isinstance(cells, torch.Tensor): + box_input = ( + cells.reshape([-1, 3, 3]) + .clone() + .detach() + .to(dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE) + ) + else: + box_input = torch.tensor( + cells.reshape([-1, 3, 3]), + dtype=GLOBAL_PT_FLOAT_PRECISION, + device=DEVICE, + ) num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size) for ii in range(num_iter): diff --git a/source/tests/pt/test_calculator.py b/source/tests/pt/test_calculator.py index 860a161fbd..7458117ca3 100644 --- a/source/tests/pt/test_calculator.py +++ b/source/tests/pt/test_calculator.py @@ -64,14 +64,18 @@ def test_calculator(self) -> None: atomic_numbers = [1, 1, 1, 8, 8] idx_perm = [1, 0, 4, 3, 2] + # Convert tensors to numpy for ASE compatibility + cell_np = cell.numpy() + coord_np = coord.numpy() + prec = 1e-10 low_prec = 1e-4 ase_atoms0 = Atoms( numbers=atomic_numbers, - positions=coord, + positions=coord_np, # positions=[tuple(item) for item in coordinate], - cell=cell, + cell=cell_np, calculator=self.calculator, pbc=True, ) @@ -83,9 +87,9 @@ def test_calculator(self) -> None: ase_atoms1 = Atoms( numbers=[atomic_numbers[i] for i in idx_perm], - positions=coord[idx_perm, :], + positions=coord_np[idx_perm, :], # positions=[tuple(item) for item in coordinate], - cell=cell, + cell=cell_np, calculator=self.calculator, pbc=True, ) @@ -141,19 +145,23 @@ def test_calculator(self) -> None: generator = torch.Generator(device="cpu").manual_seed(GLOBAL_SEED) coord = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator) coord = torch.matmul(coord, cell) - fparam = torch.IntTensor([1, 2]) - aparam = torch.IntTensor([[1], [0], [2], [1], [0]]) + fparam = torch.IntTensor([1, 2]).numpy() + aparam = torch.IntTensor([[1], [0], [2], [1], [0]]).numpy() atomic_numbers = [1, 1, 1, 8, 8] idx_perm = [1, 0, 4, 3, 2] + # Convert tensors to numpy for ASE compatibility + cell_np = cell.numpy() + coord_np = coord.numpy() + prec = 1e-10 low_prec = 1e-4 ase_atoms0 = Atoms( numbers=atomic_numbers, - positions=coord, + positions=coord_np, # positions=[tuple(item) for item in coordinate], - cell=cell, + cell=cell_np, calculator=self.calculator, pbc=True, ) @@ -166,9 +174,9 @@ def test_calculator(self) -> None: ase_atoms1 = Atoms( numbers=[atomic_numbers[i] for i in idx_perm], - positions=coord[idx_perm, :], + positions=coord_np[idx_perm, :], # positions=[tuple(item) for item in coordinate], - cell=cell, + cell=cell_np, calculator=self.calculator, pbc=True, )