From ff7767c2a867aa83caaac973a1b5b524e84c5245 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 22 Nov 2024 02:49:45 -0500 Subject: [PATCH 1/5] fix(pt): use eval mode in the C++ interface Signed-off-by: Jinzhe Zeng --- source/api_cc/src/DeepPotPT.cc | 1 + source/api_cc/src/DeepSpinPT.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index 7e5d391b1f..b431ad65cf 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -88,6 +88,7 @@ void DeepPotPT::init(const std::string& model, } std::unordered_map metadata = {{"type", ""}}; module = torch::jit::load(model, device, metadata); + module.eval(); do_message_passing = module.run_method("has_message_passing").toBool(); torch::jit::FusionStrategy strategy; strategy = {{torch::jit::FusionBehavior::DYNAMIC, 10}}; diff --git a/source/api_cc/src/DeepSpinPT.cc b/source/api_cc/src/DeepSpinPT.cc index c72cb34b15..1a245c7b2e 100644 --- a/source/api_cc/src/DeepSpinPT.cc +++ b/source/api_cc/src/DeepSpinPT.cc @@ -88,6 +88,7 @@ void DeepSpinPT::init(const std::string& model, } std::unordered_map metadata = {{"type", ""}}; module = torch::jit::load(model, device, metadata); + module.eval(); do_message_passing = module.run_method("has_message_passing").toBool(); torch::jit::FusionStrategy strategy; strategy = {{torch::jit::FusionBehavior::DYNAMIC, 10}}; From e4453720c71d80af0f92e196ab1bf9f69b4b254e Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 22 Nov 2024 13:58:29 -0500 Subject: [PATCH 2/5] Freeze TorchScript modules with preserved attributes --- source/api_cc/src/DeepPotPT.cc | 2 ++ source/api_cc/src/DeepSpinPT.cc | 2 ++ 2 files changed, 4 insertions(+) diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index b431ad65cf..4f17c30e25 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -89,6 +89,8 @@ void DeepPotPT::init(const std::string& model, std::unordered_map metadata = {{"type", ""}}; module = torch::jit::load(model, device, metadata); module.eval(); + const std::vector& preserved_attrs = {"forward_lower", "has_message_passing", "get_rcut", "get_ntypes", "get_dim_fparam", "get_dim_aparam", "is_aparam_nall", "get_type_map"}; + module = torch::jit::freeze(module, preserved_attrs); do_message_passing = module.run_method("has_message_passing").toBool(); torch::jit::FusionStrategy strategy; strategy = {{torch::jit::FusionBehavior::DYNAMIC, 10}}; diff --git a/source/api_cc/src/DeepSpinPT.cc b/source/api_cc/src/DeepSpinPT.cc index 1a245c7b2e..fac494772c 100644 --- a/source/api_cc/src/DeepSpinPT.cc +++ b/source/api_cc/src/DeepSpinPT.cc @@ -89,6 +89,8 @@ void DeepSpinPT::init(const std::string& model, std::unordered_map metadata = {{"type", ""}}; module = torch::jit::load(model, device, metadata); module.eval(); + const std::vector& preserved_attrs = {"forward_lower", "has_message_passing", "get_rcut", "get_ntypes", "get_dim_fparam", "get_dim_aparam", "is_aparam_nall", "get_type_map"}; + module = torch::jit::freeze(module, preserved_attrs); do_message_passing = module.run_method("has_message_passing").toBool(); torch::jit::FusionStrategy strategy; strategy = {{torch::jit::FusionBehavior::DYNAMIC, 10}}; From f4c9198b43b4fcbab15acbfa32189c80b1e29b5e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 22 Nov 2024 18:59:50 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/api_cc/src/DeepPotPT.cc | 5 ++++- source/api_cc/src/DeepSpinPT.cc | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index 4f17c30e25..68da1cabd2 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -89,7 +89,10 @@ void DeepPotPT::init(const std::string& model, std::unordered_map metadata = {{"type", ""}}; module = torch::jit::load(model, device, metadata); module.eval(); - const std::vector& preserved_attrs = {"forward_lower", "has_message_passing", "get_rcut", "get_ntypes", "get_dim_fparam", "get_dim_aparam", "is_aparam_nall", "get_type_map"}; + const std::vector& preserved_attrs = { + "forward_lower", "has_message_passing", "get_rcut", + "get_ntypes", "get_dim_fparam", "get_dim_aparam", + "is_aparam_nall", "get_type_map"}; module = torch::jit::freeze(module, preserved_attrs); do_message_passing = module.run_method("has_message_passing").toBool(); torch::jit::FusionStrategy strategy; diff --git a/source/api_cc/src/DeepSpinPT.cc b/source/api_cc/src/DeepSpinPT.cc index fac494772c..1add9a4068 100644 --- a/source/api_cc/src/DeepSpinPT.cc +++ b/source/api_cc/src/DeepSpinPT.cc @@ -89,7 +89,10 @@ void DeepSpinPT::init(const std::string& model, std::unordered_map metadata = {{"type", ""}}; module = torch::jit::load(model, device, metadata); module.eval(); - const std::vector& preserved_attrs = {"forward_lower", "has_message_passing", "get_rcut", "get_ntypes", "get_dim_fparam", "get_dim_aparam", "is_aparam_nall", "get_type_map"}; + const std::vector& preserved_attrs = { + "forward_lower", "has_message_passing", "get_rcut", + "get_ntypes", "get_dim_fparam", "get_dim_aparam", + "is_aparam_nall", "get_type_map"}; module = torch::jit::freeze(module, preserved_attrs); do_message_passing = module.run_method("has_message_passing").toBool(); torch::jit::FusionStrategy strategy; From a4fc9c8a51626595fb0a36f6fc5f32977533e351 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 22 Nov 2024 14:18:22 -0500 Subject: [PATCH 4/5] Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks" This reverts commit f4c9198b43b4fcbab15acbfa32189c80b1e29b5e. --- source/api_cc/src/DeepPotPT.cc | 5 +---- source/api_cc/src/DeepSpinPT.cc | 5 +---- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index 68da1cabd2..4f17c30e25 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -89,10 +89,7 @@ void DeepPotPT::init(const std::string& model, std::unordered_map metadata = {{"type", ""}}; module = torch::jit::load(model, device, metadata); module.eval(); - const std::vector& preserved_attrs = { - "forward_lower", "has_message_passing", "get_rcut", - "get_ntypes", "get_dim_fparam", "get_dim_aparam", - "is_aparam_nall", "get_type_map"}; + const std::vector& preserved_attrs = {"forward_lower", "has_message_passing", "get_rcut", "get_ntypes", "get_dim_fparam", "get_dim_aparam", "is_aparam_nall", "get_type_map"}; module = torch::jit::freeze(module, preserved_attrs); do_message_passing = module.run_method("has_message_passing").toBool(); torch::jit::FusionStrategy strategy; diff --git a/source/api_cc/src/DeepSpinPT.cc b/source/api_cc/src/DeepSpinPT.cc index 1add9a4068..fac494772c 100644 --- a/source/api_cc/src/DeepSpinPT.cc +++ b/source/api_cc/src/DeepSpinPT.cc @@ -89,10 +89,7 @@ void DeepSpinPT::init(const std::string& model, std::unordered_map metadata = {{"type", ""}}; module = torch::jit::load(model, device, metadata); module.eval(); - const std::vector& preserved_attrs = { - "forward_lower", "has_message_passing", "get_rcut", - "get_ntypes", "get_dim_fparam", "get_dim_aparam", - "is_aparam_nall", "get_type_map"}; + const std::vector& preserved_attrs = {"forward_lower", "has_message_passing", "get_rcut", "get_ntypes", "get_dim_fparam", "get_dim_aparam", "is_aparam_nall", "get_type_map"}; module = torch::jit::freeze(module, preserved_attrs); do_message_passing = module.run_method("has_message_passing").toBool(); torch::jit::FusionStrategy strategy; From 530f98e995af66b71faefe5eba695185b27c7449 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 22 Nov 2024 14:18:23 -0500 Subject: [PATCH 5/5] Revert "Freeze TorchScript modules with preserved attributes" This reverts commit e4453720c71d80af0f92e196ab1bf9f69b4b254e. --- source/api_cc/src/DeepPotPT.cc | 2 -- source/api_cc/src/DeepSpinPT.cc | 2 -- 2 files changed, 4 deletions(-) diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index 4f17c30e25..b431ad65cf 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -89,8 +89,6 @@ void DeepPotPT::init(const std::string& model, std::unordered_map metadata = {{"type", ""}}; module = torch::jit::load(model, device, metadata); module.eval(); - const std::vector& preserved_attrs = {"forward_lower", "has_message_passing", "get_rcut", "get_ntypes", "get_dim_fparam", "get_dim_aparam", "is_aparam_nall", "get_type_map"}; - module = torch::jit::freeze(module, preserved_attrs); do_message_passing = module.run_method("has_message_passing").toBool(); torch::jit::FusionStrategy strategy; strategy = {{torch::jit::FusionBehavior::DYNAMIC, 10}}; diff --git a/source/api_cc/src/DeepSpinPT.cc b/source/api_cc/src/DeepSpinPT.cc index fac494772c..1a245c7b2e 100644 --- a/source/api_cc/src/DeepSpinPT.cc +++ b/source/api_cc/src/DeepSpinPT.cc @@ -89,8 +89,6 @@ void DeepSpinPT::init(const std::string& model, std::unordered_map metadata = {{"type", ""}}; module = torch::jit::load(model, device, metadata); module.eval(); - const std::vector& preserved_attrs = {"forward_lower", "has_message_passing", "get_rcut", "get_ntypes", "get_dim_fparam", "get_dim_aparam", "is_aparam_nall", "get_type_map"}; - module = torch::jit::freeze(module, preserved_attrs); do_message_passing = module.run_method("has_message_passing").toBool(); torch::jit::FusionStrategy strategy; strategy = {{torch::jit::FusionBehavior::DYNAMIC, 10}};