Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
35d5577
Perf: replace unnecessary `torch.split` with indexing (#4505)
caic99 Dec 25, 2024
dfaa2ba
docs: fix the header of the scaling test table (#4507)
njzjz Dec 26, 2024
a44c6ca
Fix: Modify docs of DPA models (#4510)
QuantumMisaka Dec 26, 2024
43c8cae
fix(pt): fix clearing the list in set_eval_descriptor_hook (#4534)
njzjz Jan 7, 2025
2b7f53c
[fix bug] load atomic_*.npy for tf tensor model (#4538)
ChiahsinChu Jan 7, 2025
cc37fd1
fix: lower `num_workers` to 4 (#4535)
caic99 Jan 7, 2025
5f9bbc4
feat(tf): support tensor fitting with hybrid descriptor (#4542)
njzjz Jan 10, 2025
4cb16a1
docs: add `sphinx.configuration` to .readthedocs.yml (#4553)
njzjz Jan 16, 2025
476f34d
Perf: use F.linear for MLP (#4513)
caic99 Jan 16, 2025
f199a2c
CI: switch linux_aarch64 to GitHub hosted runners (#4557)
njzjz Jan 17, 2025
bbbb426
chore: improve neighbor stat log (#4561)
njzjz Jan 20, 2025
4fc1d89
fix: fix YAML conversion (#4565)
njzjz Jan 21, 2025
cfa5064
fix(cc): remove C++ 17 usage (#4570)
njzjz Jan 24, 2025
a1290ea
chore: bump pytorch to 2.6.0 (#4575)
njzjz Feb 2, 2025
0bc73e6
Fix version in DeePMDConfigVersion.cmake (#4577)
RMeli Feb 4, 2025
44e40bf
fix(pt): detach computed descriptor tensor to prevent OOM (#4547)
njzjz Feb 5, 2025
fca0e6e
fix(pt): throw errors for GPU tensors and the CPU OP library (#4582)
njzjz Feb 7, 2025
db9071e
use variable to store the bias of atomic polarizability (#4581)
Yi-FanLi Feb 8, 2025
910b8a7
Fix: pt tensor loss label name (#4587)
anyangml Feb 8, 2025
be468b1
Fix UT `test_tf_consistent_with_ref` (#216)
Yi-FanLi Feb 10, 2025
c2e6b6a
CI: pin jax to 0.5.0 (#4613)
njzjz Feb 26, 2025
2580af8
docs: add v3 paper citations (#4619)
njzjz Feb 27, 2025
6d6c3fd
fix(array-api): fix xp.where errors (#4624)
njzjz Mar 1, 2025
b59bc33
docs: add PyTorch Profiler support details to TensorBoard documentati…
caic99 Mar 1, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .devcontainer/download_libtorch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ set -ev
SCRIPT_PATH=$(dirname $(realpath -s $0))
cd ${SCRIPT_PATH}/..

wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.5.0%2Bcpu.zip -O ~/libtorch.zip
wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.6.0%2Bcpu.zip -O ~/libtorch.zip
unzip ~/libtorch.zip
18 changes: 1 addition & 17 deletions .github/workflows/build_wheel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,8 @@ concurrency:
cancel-in-progress: true

jobs:
determine-arm64-runner:
runs-on: ubuntu-latest
permissions: read-all
outputs:
runner: ${{ steps.set-runner.outputs.runner }}
steps:
- name: Determine which runner to use for ARM64 build
id: set-runner
run: |
if [ "${{ github.repository_owner }}" == "deepmodeling" ]; then
echo "runner=[\"Linux\",\"ARM64\"]" >> $GITHUB_OUTPUT
else
echo "runner=\"ubuntu-latest\"" >> $GITHUB_OUTPUT
fi

build_wheels:
name: Build wheels for cp${{ matrix.python }}-${{ matrix.platform_id }}
needs: determine-arm64-runner
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
Expand Down Expand Up @@ -65,7 +49,7 @@ jobs:
platform_id: win_amd64
dp_variant: cpu
# linux-aarch64
- os: ${{ fromJson(needs.determine-arm64-runner.outputs.runner) }}
- os: ubuntu-24.04-arm
python: 310
platform_id: manylinux_aarch64
dp_variant: cpu
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_cc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
- run: python -m pip install uv
- name: Install Python dependencies
run: |
source/install/uv_with_retry.sh pip install --system tensorflow-cpu
source/install/uv_with_retry.sh pip install --system tensorflow-cpu~=2.18.0 jax==0.5.0
export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp,jax] mpi4py
- name: Convert models
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/test_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ jobs:
&& sudo apt-get -y install cuda-12-3 libcudnn8=8.9.5.*-1+cuda12.3
if: false # skip as we use nvidia image
- run: python -m pip install -U uv
- run: source/install/uv_with_retry.sh pip install --system "tensorflow~=2.18.0rc2" "torch~=2.5.0" "jax[cuda12]"
- run: source/install/uv_with_retry.sh pip install --system "tensorflow~=2.18.0rc2" "torch~=2.6.0" "jax[cuda12]==0.5.0"
- run: |
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
source/install/uv_with_retry.sh pip install --system -v -e .[gpu,test,lmp,cu12,torch,jax] mpi4py
source/install/uv_with_retry.sh pip install --system -v -e .[gpu,test,lmp,cu12,torch,jax] mpi4py --reinstall-package deepmd-kit
env:
DP_VARIANT: cuda
DP_ENABLE_NATIVE_OPTIMIZATION: 1
Expand All @@ -67,7 +67,7 @@ jobs:
run: source/tests/infer/convert-models.sh
- name: Download libtorch
run: |
wget https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.5.0%2Bcu124.zip -O libtorch.zip
wget https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.6.0%2Bcu124.zip -O libtorch.zip
unzip libtorch.zip
- run: |
export CMAKE_PREFIX_PATH=$GITHUB_WORKSPACE/libtorch
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test_python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ jobs:
python-version: ${{ matrix.python }}
- run: python -m pip install -U uv
- run: |
source/install/uv_with_retry.sh pip install --system openmpi tensorflow-cpu
source/install/uv_with_retry.sh pip install --system openmpi tensorflow-cpu~=2.18.0
source/install/uv_with_retry.sh pip install --system torch -i https://download.pytorch.org/whl/cpu
export TENSORFLOW_ROOT=$(python -c 'import tensorflow;print(tensorflow.__path__[0])')
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py
source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py "jax==0.5.0;python_version>='3.10'"
source/install/uv_with_retry.sh pip install --system horovod --no-build-isolation
env:
# Please note that uv has some issues with finding
Expand Down
2 changes: 2 additions & 0 deletions .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,7 @@ build:
- VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH uv pip install -r doc/requirements.txt
apt_packages:
- inkscape
sphinx:
configuration: doc/conf.py
formats:
- pdf
25 changes: 25 additions & 0 deletions CITATIONS.bib
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,31 @@ @article{Zeng_JChemPhys_2023_v159_p054801
doi = {10.1063/5.0155600},
}

@article{Zeng_arXiv_2025_p2502.19161,
annote = {general purpose},
author = {
Jinzhe Zeng and Duo Zhang and Anyang Peng and Xiangyu Zhang and Sensen He
and Yan Wang and Xinzijian Liu and Hangrui Bi and Yifan Li and Chun Cai and
Chengqian Zhang and Yiming Du and Jia-Xin Zhu and Pinghui Mo and Zhengtao
Huang and Qiyu Zeng and Shaochen Shi and Xuejian Qin and Zhaoxi Yu and
Chenxing Luo and Ye Ding and Yun-Pei Liu and Ruosong Shi and Zhenyu Wang
and Sigbj{\o}rn L{\o}land Bore and Junhan Chang and Zhe Deng and Zhaohan
Ding and Siyuan Han and Wanrun Jiang and Guolin Ke and Zhaoqing Liu and
Denghui Lu and Koki Muraoka and Hananeh Oliaei and Anurag Kumar Singh and
Haohui Que and Weihong Xu and Zhangmancang Xu and Yong-Bin Zhuang and Jiayu
Dai and Timothy J. Giese and Weile Jia and Ben Xu and Darrin M. York and
Linfeng Zhang and Han Wang
},
title = {
{DeePMD-kit v3: A Multiple-Backend Framework for Machine Learning
Potentials}
},
journal = {arXiv},
year = 2025,
pages = {2502.19161},
doi = {10.48550/arXiv.2502.19161},
}

@article{Lu_CompPhysCommun_2021_v259_p107624,
annote = {GPU support},
title = {
Expand Down
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ If you use this code in any future publications, please cite the following publi
- Jinzhe Zeng, Duo Zhang, Denghui Lu, Pinghui Mo, Zeyu Li, Yixiao Chen, Marián Rynik, Li'ang Huang, Ziyao Li, Shaochen Shi, Yingze Wang, Haotian Ye, Ping Tuo, Jiabin Yang, Ye Ding, Yifan Li, Davide Tisi, Qiyu Zeng, Han Bao, Yu Xia, Jiameng Huang, Koki Muraoka, Yibo Wang, Junhan Chang, Fengbo Yuan, Sigbjørn Løland Bore, Chun Cai, Yinnian Lin, Bo Wang, Jiayan Xu, Jia-Xin Zhu, Chenxing Luo, Yuzhi Zhang, Rhys E. A. Goodall, Wenshuo Liang, Anurag Kumar Singh, Sikai Yao, Jingchao Zhang, Renata Wentzcovitch, Jiequn Han, Jie Liu, Weile Jia, Darrin M. York, Weinan E, Roberto Car, Linfeng Zhang, Han Wang. "DeePMD-kit v2: A software package for deep potential models." J. Chem. Phys. 159 (2023): 054801.
[![doi:10.1063/5.0155600](https://img.shields.io/badge/DOI-10.1063%2F5.0155600-blue)](https://doi.org/10.1063/5.0155600)
[![Citations](https://citations.njzjz.win/10.1063/5.0155600)](https://badge.dimensions.ai/details/doi/10.1063/5.0155600)
- Jinzhe Zeng, Duo Zhang, Anyang Peng, Xiangyu Zhang, Sensen He, Yan Wang, Xinzijian Liu, Hangrui Bi, Yifan Li, Chun Cai, Chengqian Zhang, Yiming Du, Jia-Xin Zhu, Pinghui Mo, Zhengtao Huang, Qiyu Zeng, Shaochen Shi, Xuejian Qin, Zhaoxi Yu, Chenxing Luo, Ye Ding, Yun-Pei Liu, Ruosong Shi, Zhenyu Wang, Sigbjørn Løland Bore, Junhan Chang, Zhe Deng, Zhaohan Ding, Siyuan Han, Wanrun Jiang, Guolin Ke, Zhaoqing Liu, Denghui Lu, Koki Muraoka, Hananeh Oliaei, Anurag Kumar Singh, Haohui Que, Weihong Xu, Zhangmancang Xu, Yong-Bin Zhuang, Jiayu Dai, Timothy J. Giese, Weile Jia, Ben Xu, Darrin M. York, Linfeng Zhang, Han Wang. "DeePMD-kit v3: A Multiple-Backend Framework for Machine Learning Potentials." [arXiv:2502.19161](https://arxiv.org/abs/2502.19161).

In addition, please follow [the bib file](CITATIONS.bib) to cite the methods you used.

Expand Down Expand Up @@ -68,14 +69,16 @@ In addition to building up potential energy models, DeePMD-kit can also be used
- Non-von-Neumann.
- C API to interface with the third-party packages.

See [our latest paper](https://doi.org/10.1063/5.0155600) for details of all features until v2.2.3.
See [our v2 paper](https://doi.org/10.1063/5.0155600) for details of all features until v2.2.3.

#### v3

- Multiple backends supported. Add PyTorch and JAX backends.
- The DPA-2 model.
- Plugin mechanisms for external models.

See [our v3 paper](https://doi.org/10.48550/arXiv.2502.19161) for details of all features until v3.0.

## Install and use DeePMD-kit

Please read the [online documentation](https://deepmd.readthedocs.io/) for how to install and use DeePMD-kit.
Expand Down
2 changes: 1 addition & 1 deletion backend/find_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def get_pt_requirement(pt_version: str = "") -> dict:
cuda_version = os.environ.get("CUDA_VERSION", "12.2")
if cuda_version == "" or cuda_version in SpecifierSet(">=12,<13"):
# CUDA 12.2, cudnn 9
pt_version = "2.5.0"
pt_version = "2.6.0"
elif cuda_version in SpecifierSet(">=11,<12"):
# CUDA 11.8, cudnn 8
pt_version = "2.3.1"
Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,7 @@ def call(
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
# nfnl x nnei
exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
exclude_mask = xp.astype(exclude_mask, xp.bool)
# nfnl x nnei
nlist = xp.reshape(nlist, (nf * nloc, nnei))
nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1))
Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ def call(
):
xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
exclude_mask = xp.astype(exclude_mask, xp.bool)
nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1))
# nf x nloc x nnei x 4
dmatrix, diff, sw = self.env_mat.call(
Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,7 @@ def call(
exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
# nfnl x nnei
nlist = xp.reshape(nlist, (nf * nloc, nnei))
exclude_mask = xp.astype(exclude_mask, xp.bool)
nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1))
# nfnl x nnei
nlist_mask = nlist != -1
Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ def _call_common(
)
# nf x nloc
exclude_mask = self.emask.build_type_exclude_mask(atype)
exclude_mask = xp.astype(exclude_mask, xp.bool)
# nf x nloc x nod
outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs))
return {self.var_name: outs}
4 changes: 3 additions & 1 deletion deepmd/dpmodel/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ def save_dp_model(filename: str, model_dict: dict) -> None:
"@version": 1,
"dtype": x.dtype.name,
"value": x.tolist(),
},
}
if isinstance(x, np.ndarray)
else x,
)
with open(filename, "w") as f:
yaml.safe_dump(
Expand Down
5 changes: 3 additions & 2 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def __init__(
def set_eval_descriptor_hook(self, enable: bool) -> None:
"""Set the hook for evaluating descriptor and clear the cache for descriptor list."""
self.enable_eval_descriptor_hook = enable
self.eval_descriptor_list = []
# = [] does not work; See #4533
self.eval_descriptor_list.clear()

def eval_descriptor(self) -> torch.Tensor:
"""Evaluate the descriptor."""
Expand Down Expand Up @@ -236,7 +237,7 @@ def forward_atomic(
)
assert descriptor is not None
if self.enable_eval_descriptor_hook:
self.eval_descriptor_list.append(descriptor)
self.eval_descriptor_list.append(descriptor.detach())
# energy, force
fit_ret = self.fitting_net(
descriptor,
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/repformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,7 +1003,7 @@ def _cal_grrg(h2g2: torch.Tensor, axis_neuron: int) -> torch.Tensor:
# nb x nloc x 3 x ng2
nb, nloc, _, ng2 = h2g2.shape
# nb x nloc x 3 x axis
h2g2m = torch.split(h2g2, axis_neuron, dim=-1)[0]
h2g2m = h2g2[..., :axis_neuron]
# nb x nloc x axis x ng2
g1_13 = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1)
# nb x nloc x (axisxng2)
Expand Down
13 changes: 5 additions & 8 deletions deepmd/pt/model/network/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from deepmd.pt.utils import (
env,
Expand Down Expand Up @@ -202,18 +203,14 @@ def forward(
ori_prec = xx.dtype
if not env.DP_DTYPE_PROMOTION_STRICT:
xx = xx.to(self.prec)
yy = (
torch.matmul(xx, self.matrix) + self.bias
if self.bias is not None
else torch.matmul(xx, self.matrix)
)
yy = self.activate(yy).clone()
yy = F.linear(xx, self.matrix.t(), self.bias)
yy = self.activate(yy)
yy = yy * self.idt if self.idt is not None else yy
if self.resnet:
if xx.shape[-1] == yy.shape[-1]:
yy += xx
yy = yy + xx
elif 2 * xx.shape[-1] == yy.shape[-1]:
yy += torch.concat([xx, xx], dim=-1)
yy = yy + torch.concat([xx, xx], dim=-1)
else:
yy = yy
if not env.DP_DTYPE_PROMOTION_STRICT:
Expand Down
10 changes: 4 additions & 6 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,13 +1230,11 @@ def get_loss(loss_params, start_lr, _ntypes, _model):
if "mask" in model_output_type:
model_output_type.pop(model_output_type.index("mask"))
tensor_name = model_output_type[0]
loss_params["tensor_name"] = tensor_name
loss_params["tensor_size"] = _model.model_output_def()[tensor_name].output_size
label_name = tensor_name
if label_name == "polarizability":
label_name = "polar"
loss_params["label_name"] = label_name
loss_params["tensor_name"] = label_name
loss_params["label_name"] = tensor_name
if tensor_name == "polarizability":
tensor_name = "polar"
loss_params["tensor_name"] = tensor_name
return TensorLoss(**loss_params)
elif loss_type == "property":
task_dim = _model.get_task_dim()
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
ncpus = len(os.sched_getaffinity(0))
except AttributeError:
ncpus = os.cpu_count()
NUM_WORKERS = int(os.environ.get("NUM_WORKERS", min(8, ncpus)))
NUM_WORKERS = int(os.environ.get("NUM_WORKERS", min(4, ncpus)))
# Make sure DDP uses correct device if applicable
LOCAL_RANK = os.environ.get("LOCAL_RANK")
LOCAL_RANK = int(0 if LOCAL_RANK is None else LOCAL_RANK)
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def nlist_distinguish_types(
inlist = torch.gather(nlist, 2, imap)
inlist = inlist.masked_fill(~(pick_mask.to(torch.bool)), -1)
# nloc x nsel[ii]
ret_nlist.append(torch.split(inlist, [ss, snsel - ss], dim=-1)[0])
ret_nlist.append(inlist[..., :ss])
return torch.concat(ret_nlist, dim=-1)


Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,9 @@ def compute_output_stats_global(
# subtract the model bias and output the delta bias

stats_input = {
kk: merged_output[kk] - model_pred[kk] for kk in keys if kk in merged_output
kk: merged_output[kk] - model_pred[kk].reshape(merged_output[kk].shape)
for kk in keys
if kk in merged_output
}

bias_atom_e = {}
Expand Down
9 changes: 8 additions & 1 deletion deepmd/tf/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ def get_dim_rot_mat_1(self) -> int:
int
the first dimension of the rotation matrix
"""
raise NotImplementedError
# by default, no rotation matrix
return 0

def get_nlist(self) -> tuple[tf.Tensor, tf.Tensor, list[int], list[int]]:
"""Returns neighbor information.
Expand Down Expand Up @@ -534,3 +535,9 @@ def serialize(self, suffix: str = "") -> dict:
def input_requirement(self) -> list[DataRequirementItem]:
"""Return data requirements needed for the model input."""
return []

def get_rot_mat(self) -> tf.Tensor:
"""Get rotational matrix."""
nframes = tf.shape(self.dout)[0]
natoms = tf.shape(self.dout)[1]
return tf.zeros([nframes, natoms, 0], dtype=GLOBAL_TF_FLOAT_PRECISION)
18 changes: 18 additions & 0 deletions deepmd/tf/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,3 +492,21 @@ def deserialize(cls, data: dict, suffix: str = "") -> "DescrptHybrid":
if hasattr(ii, "type_embedding"):
raise NotImplementedError("hybrid + type embedding is not supported")
return obj

def get_dim_rot_mat_1(self) -> int:
"""Returns the first dimension of the rotation matrix. The rotation is of shape
dim_1 x 3.

Returns
-------
int
the first dimension of the rotation matrix
"""
return sum([ii.get_dim_rot_mat_1() for ii in self.descrpt_list])

def get_rot_mat(self) -> tf.Tensor:
"""Get rotational matrix."""
all_rot_mat = []
for ii in self.descrpt_list:
all_rot_mat.append(ii.get_rot_mat())
return tf.concat(all_rot_mat, axis=2)
Loading