From 2e3e0214da7eb7e260e31c65481bf08791b7a7fa Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 30 Jun 2026 01:50:53 +0800 Subject: [PATCH 1/3] test(consistent): use curated cases throughout --- .../consistent/descriptor/test_se_atten_v2.py | 109 ++++++++++++------ .../tests/consistent/descriptor/test_se_r.py | 50 ++++++-- .../tests/consistent/descriptor/test_se_t.py | 51 ++++++-- .../consistent/descriptor/test_se_t_tebd.py | 80 ++++++++----- .../tests/consistent/fitting/test_dipole.py | 41 ++++++- source/tests/consistent/fitting/test_dos.py | 52 +++++++-- .../consistent/fitting/test_dpa4_ener.py | 13 ++- source/tests/consistent/fitting/test_ener.py | 68 ++++++++--- source/tests/consistent/fitting/test_polar.py | 32 ++++- .../tests/consistent/fitting/test_property.py | 57 +++++++-- source/tests/consistent/loss/test_dos.py | 12 +- source/tests/consistent/loss/test_ener.py | 52 +++++++-- .../tests/consistent/loss/test_ener_spin.py | 15 ++- source/tests/consistent/loss/test_property.py | 13 ++- source/tests/consistent/loss/test_tensor.py | 12 +- source/tests/consistent/model/test_dipole.py | 13 ++- source/tests/consistent/model/test_dos.py | 13 ++- source/tests/consistent/model/test_dpa1.py | 12 +- source/tests/consistent/model/test_ener.py | 52 ++++----- source/tests/consistent/model/test_frozen.py | 10 +- .../consistent/model/test_linear_ener.py | 20 ++-- source/tests/consistent/model/test_polar.py | 13 ++- .../tests/consistent/model/test_property.py | 13 ++- .../tests/consistent/model/test_zbl_ener.py | 28 ++--- source/tests/consistent/test_activation.py | 24 ++-- source/tests/consistent/test_learning_rate.py | 15 ++- .../tests/consistent/test_type_embedding.py | 47 ++++++-- 27 files changed, 663 insertions(+), 254 deletions(-) diff --git a/source/tests/consistent/descriptor/test_se_atten_v2.py b/source/tests/consistent/descriptor/test_se_atten_v2.py index 1cc644c73c..cc4ef67f9c 100644 --- a/source/tests/consistent/descriptor/test_se_atten_v2.py +++ b/source/tests/consistent/descriptor/test_se_atten_v2.py @@ -21,7 +21,7 @@ INSTALLED_PT, INSTALLED_PT_EXPT, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( DescriptorAPITest, @@ -63,26 +63,77 @@ descrpt_se_atten_args, ) +SE_ATTEN_V2_CASE_FIELDS = ( + "tebd_dim", + "resnet_dt", + "type_one_side", + "attn", + "attn_layer", + "attn_dotr", + "excluded_types", + "env_protection", + "set_davg_zero", + "scaling_factor", + "normalize", + "temperature", + "ln_eps", + "concat_output_tebd", + "precision", + "use_econf_tebd", + "use_tebd_bias", +) -@parameterized( - (4,), # tebd_dim - (True,), # resnet_dt - (True, False), # type_one_side - (20,), # attn - (0, 2), # attn_layer - (True, False), # attn_dotr - ([], [[0, 1]]), # excluded_types - (0.0,), # env_protection - (True, False), # set_davg_zero - (1.0,), # scaling_factor - (True, False), # normalize - (None, 1.0), # temperature - (1e-5,), # ln_eps - (True,), # concat_output_tebd - ("float64",), # precision - (True, False), # use_econf_tebd - (False,), # use_tebd_bias +SE_ATTEN_V2_BASELINE_CASE = { + "tebd_dim": 4, + "resnet_dt": True, + "type_one_side": True, + "attn": 20, + "attn_layer": 2, + "attn_dotr": True, + "excluded_types": [], + "env_protection": 0.0, + "set_davg_zero": True, + "scaling_factor": 1.0, + "normalize": True, + "temperature": None, + "ln_eps": 1e-5, + "concat_output_tebd": True, + "precision": "float64", + "use_econf_tebd": True, + "use_tebd_bias": False, +} + + +def se_atten_v2_case(**overrides: Any) -> tuple: + case = SE_ATTEN_V2_BASELINE_CASE | overrides + return tuple(case[field] for field in SE_ATTEN_V2_CASE_FIELDS) + + +SE_ATTEN_V2_CURATED_CASES = ( + se_atten_v2_case(), + se_atten_v2_case(type_one_side=False), + se_atten_v2_case(attn_layer=0, attn_dotr=False, normalize=False), + se_atten_v2_case(attn_dotr=False), + se_atten_v2_case(excluded_types=[[0, 1]]), + se_atten_v2_case(set_davg_zero=False), + se_atten_v2_case(normalize=False), + se_atten_v2_case(temperature=1.0), + se_atten_v2_case(use_econf_tebd=False), + se_atten_v2_case( + type_one_side=False, + attn_dotr=False, + excluded_types=[[0, 1]], + set_davg_zero=False, + normalize=False, + temperature=1.0, + use_econf_tebd=False, + ), ) + +SE_ATTEN_V2_DESCRIPTOR_API_CURATED_CASES = SE_ATTEN_V2_CURATED_CASES + + +@parameterized_cases(*SE_ATTEN_V2_CURATED_CASES) class TestSeAttenV2(CommonTest, DescriptorTest, unittest.TestCase): @property def data(self) -> dict: @@ -497,25 +548,7 @@ def atol(self) -> float: raise ValueError(f"Unknown precision: {precision}") -@parameterized( - (4,), # tebd_dim - (True,), # resnet_dt - (True, False), # type_one_side - (20,), # attn - (0, 2), # attn_layer - (True, False), # attn_dotr - ([], [[0, 1]]), # excluded_types - (0.0,), # env_protection - (True, False), # set_davg_zero - (1.0,), # scaling_factor - (True, False), # normalize - (None, 1.0), # temperature - (1e-5,), # ln_eps - (True,), # concat_output_tebd - ("float64",), # precision - (True, False), # use_econf_tebd - (False,), # use_tebd_bias -) +@parameterized_cases(*SE_ATTEN_V2_DESCRIPTOR_API_CURATED_CASES) class TestSeAttenV2DescriptorAPI(DescriptorAPITest, unittest.TestCase): """Test consistency of BaseDescriptor API methods across backends.""" diff --git a/source/tests/consistent/descriptor/test_se_r.py b/source/tests/consistent/descriptor/test_se_r.py index 826eaf2145..7d002cd955 100644 --- a/source/tests/consistent/descriptor/test_se_r.py +++ b/source/tests/consistent/descriptor/test_se_r.py @@ -18,7 +18,7 @@ INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( DescriptorAPITest, @@ -53,12 +53,43 @@ DescrptSeRArrayAPIStrict = None -@parameterized( - (True, False), # resnet_dt - (True, False), # type_one_side - ([], [[0, 1]]), # excluded_types - ("float32", "float64"), # precision +SE_R_CASE_FIELDS = ( + "resnet_dt", + "type_one_side", + "excluded_types", + "precision", ) + +SE_R_BASELINE_CASE = { + "resnet_dt": True, + "type_one_side": True, + "excluded_types": [], + "precision": "float64", +} + + +def se_r_case(**overrides: Any) -> tuple: + case = SE_R_BASELINE_CASE | overrides + return tuple(case[field] for field in SE_R_CASE_FIELDS) + + +SE_R_CURATED_CASES = ( + se_r_case(), + se_r_case(resnet_dt=False), + se_r_case(type_one_side=False), + se_r_case(excluded_types=[[0, 1]]), + se_r_case(precision="float32"), +) + +SE_R_DESCRIPTOR_API_CURATED_CASES = ( + se_r_case(), + se_r_case(resnet_dt=False), + se_r_case(type_one_side=False), + se_r_case(excluded_types=[[0, 1]]), +) + + +@parameterized_cases(*SE_R_CURATED_CASES) class TestSeR(CommonTest, DescriptorTest, unittest.TestCase): @property def data(self) -> dict: @@ -264,12 +295,7 @@ def atol(self) -> float: raise ValueError(f"Unknown precision: {precision}") -@parameterized( - (True, False), # resnet_dt - (True, False), # type_one_side - ([], [[0, 1]]), # excluded_types - ("float64",), # precision -) +@parameterized_cases(*SE_R_DESCRIPTOR_API_CURATED_CASES) class TestSeRDescriptorAPI(DescriptorAPITest, unittest.TestCase): """Test consistency of BaseDescriptor API methods across backends.""" diff --git a/source/tests/consistent/descriptor/test_se_t.py b/source/tests/consistent/descriptor/test_se_t.py index 7d2a33aba9..bacc62eb7f 100644 --- a/source/tests/consistent/descriptor/test_se_t.py +++ b/source/tests/consistent/descriptor/test_se_t.py @@ -18,7 +18,7 @@ INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( DescriptorAPITest, @@ -49,13 +49,45 @@ descrpt_se_t_args, ) +SE_T_CASE_FIELDS = ( + "resnet_dt", + "excluded_types", + "precision", + "env_protection", +) + +SE_T_BASELINE_CASE = { + "resnet_dt": True, + "excluded_types": [], + "precision": "float64", + "env_protection": 0.0, +} + + +def se_t_case(**overrides: Any) -> tuple: + case = SE_T_BASELINE_CASE | overrides + return tuple(case[field] for field in SE_T_CASE_FIELDS) + -@parameterized( - (True, False), # resnet_dt - ([], [[0, 1]]), # excluded_types - ("float32", "float64"), # precision - (0.0, 1e-8, 1e-2), # env_protection +SE_T_CURATED_CASES = ( + se_t_case(), + se_t_case(resnet_dt=False), + se_t_case(excluded_types=[[0, 1]]), + se_t_case(precision="float32"), + se_t_case(env_protection=1e-8), + se_t_case(env_protection=1e-2), ) + +SE_T_DESCRIPTOR_API_CURATED_CASES = ( + se_t_case(), + se_t_case(resnet_dt=False), + se_t_case(excluded_types=[[0, 1]]), + se_t_case(env_protection=1e-8), + se_t_case(env_protection=1e-2), +) + + +@parameterized_cases(*SE_T_CURATED_CASES) class TestSeT(CommonTest, DescriptorTest, unittest.TestCase): @property def data(self) -> dict: @@ -264,12 +296,7 @@ def atol(self) -> float: raise ValueError(f"Unknown precision: {precision}") -@parameterized( - (True, False), # resnet_dt - ([], [[0, 1]]), # excluded_types - ("float64",), # precision - (0.0, 1e-8, 1e-2), # env_protection -) +@parameterized_cases(*SE_T_DESCRIPTOR_API_CURATED_CASES) class TestSeTDescriptorAPI(DescriptorAPITest, unittest.TestCase): """Test consistency of BaseDescriptor API methods across backends.""" diff --git a/source/tests/consistent/descriptor/test_se_t_tebd.py b/source/tests/consistent/descriptor/test_se_t_tebd.py index 4017e059f5..518907296d 100644 --- a/source/tests/consistent/descriptor/test_se_t_tebd.py +++ b/source/tests/consistent/descriptor/test_se_t_tebd.py @@ -21,7 +21,7 @@ INSTALLED_PT, INSTALLED_PT_EXPT, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( DescriptorAPITest, @@ -57,20 +57,60 @@ descrpt_se_e3_tebd_args, ) +SE_T_TEBD_CASE_FIELDS = ( + "tebd_dim", + "tebd_input_mode", + "resnet_dt", + "excluded_types", + "env_protection", + "set_davg_zero", + "smooth", + "concat_output_tebd", + "precision", + "use_econf_tebd", + "use_tebd_bias", +) + +SE_T_TEBD_BASELINE_CASE = { + "tebd_dim": 4, + "tebd_input_mode": "strip", + "resnet_dt": True, + "excluded_types": [], + "env_protection": 0.0, + "set_davg_zero": True, + "smooth": True, + "concat_output_tebd": True, + "precision": "float64", + "use_econf_tebd": True, + "use_tebd_bias": False, +} + + +def se_t_tebd_case(**overrides: Any) -> tuple: + case = SE_T_TEBD_BASELINE_CASE | overrides + return tuple(case[field] for field in SE_T_TEBD_CASE_FIELDS) + -@parameterized( - (4,), # tebd_dim - ("strip",), # tebd_input_mode - (True,), # resnet_dt - ([], [[0, 1]]), # excluded_types - (0.0,), # env_protection - (True, False), # set_davg_zero - (True, False), # smooth - (True,), # concat_output_tebd - ("float64",), # precision - (True, False), # use_econf_tebd - (False, True), # use_tebd_bias +SE_T_TEBD_CURATED_CASES = ( + se_t_tebd_case(), + se_t_tebd_case(excluded_types=[[0, 1]]), + se_t_tebd_case(set_davg_zero=False), + se_t_tebd_case(smooth=False), + se_t_tebd_case(use_econf_tebd=False), + se_t_tebd_case(use_tebd_bias=True), + se_t_tebd_case( + excluded_types=[[0, 1]], + set_davg_zero=False, + smooth=False, + use_econf_tebd=False, + use_tebd_bias=True, + ), ) + +SE_T_TEBD_DESCRIPTOR_API_CURATED_CASES = SE_T_TEBD_CURATED_CASES + + +@parameterized_cases(*SE_T_TEBD_CURATED_CASES) class TestSeTTebd(CommonTest, DescriptorTest, unittest.TestCase): @property def data(self) -> dict: @@ -357,19 +397,7 @@ def atol(self) -> float: raise ValueError(f"Unknown precision: {precision}") -@parameterized( - (4,), # tebd_dim - ("strip",), # tebd_input_mode - (True,), # resnet_dt - ([], [[0, 1]]), # excluded_types - (0.0,), # env_protection - (True, False), # set_davg_zero - (True, False), # smooth - (True,), # concat_output_tebd - ("float64",), # precision - (True, False), # use_econf_tebd - (False, True), # use_tebd_bias -) +@parameterized_cases(*SE_T_TEBD_DESCRIPTOR_API_CURATED_CASES) class TestSeTTebdDescriptorAPI(DescriptorAPITest, unittest.TestCase): """Test consistency of BaseDescriptor API methods across backends.""" diff --git a/source/tests/consistent/fitting/test_dipole.py b/source/tests/consistent/fitting/test_dipole.py index 245744a93e..736483df54 100644 --- a/source/tests/consistent/fitting/test_dipole.py +++ b/source/tests/consistent/fitting/test_dipole.py @@ -21,7 +21,7 @@ INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( DipoleFittingTest, @@ -64,13 +64,42 @@ fitting_dipole, ) +DIPOLE_FITTING_CASE_FIELDS = ( + "resnet_dt", + "precision", + "mixed_types", + "sel_type", +) + +DIPOLE_FITTING_BASELINE_CASE = { + "resnet_dt": True, + "precision": "float64", + "mixed_types": True, + "sel_type": None, +} + + +def dipole_fitting_case(**overrides: Any) -> tuple: + case = DIPOLE_FITTING_BASELINE_CASE | overrides + return tuple(case[field] for field in DIPOLE_FITTING_CASE_FIELDS) -@parameterized( - (True, False), # resnet_dt - ("float64", "float32"), # precision - (True, False), # mixed_types - (None, [0]), # sel_type + +DIPOLE_FITTING_CURATED_CASES = ( + dipole_fitting_case(), + dipole_fitting_case(resnet_dt=False), + dipole_fitting_case(precision="float32"), + dipole_fitting_case(mixed_types=False), + dipole_fitting_case(sel_type=[0]), + dipole_fitting_case( + resnet_dt=False, + precision="float32", + mixed_types=False, + sel_type=[0], + ), ) + + +@parameterized_cases(*DIPOLE_FITTING_CURATED_CASES) class TestDipole(CommonTest, DipoleFittingTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/fitting/test_dos.py b/source/tests/consistent/fitting/test_dos.py index f758c9d317..da42ea696d 100644 --- a/source/tests/consistent/fitting/test_dos.py +++ b/source/tests/consistent/fitting/test_dos.py @@ -21,7 +21,7 @@ INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( FittingTest, @@ -62,14 +62,50 @@ DOSFittingStrict = object -@parameterized( - (True, False), # resnet_dt - ("float64", "float32"), # precision - (True, False), # mixed_types - (0, 1), # numb_fparam - (0, 1), # numb_aparam - (10, 20), # numb_dos +DOS_FITTING_CASE_FIELDS = ( + "resnet_dt", + "precision", + "mixed_types", + "numb_fparam", + "numb_aparam", + "numb_dos", ) + +DOS_FITTING_BASELINE_CASE = { + "resnet_dt": True, + "precision": "float64", + "mixed_types": True, + "numb_fparam": 0, + "numb_aparam": 0, + "numb_dos": 10, +} + + +def dos_fitting_case(**overrides: Any) -> tuple: + case = DOS_FITTING_BASELINE_CASE | overrides + return tuple(case[field] for field in DOS_FITTING_CASE_FIELDS) + + +DOS_FITTING_CURATED_CASES = ( + dos_fitting_case(), + dos_fitting_case(resnet_dt=False), + dos_fitting_case(precision="float32"), + dos_fitting_case(mixed_types=False), + dos_fitting_case(numb_fparam=1), + dos_fitting_case(numb_aparam=1), + dos_fitting_case(numb_dos=20), + dos_fitting_case( + resnet_dt=False, + precision="float32", + mixed_types=False, + numb_fparam=1, + numb_aparam=1, + numb_dos=20, + ), +) + + +@parameterized_cases(*DOS_FITTING_CURATED_CASES) class TestDOS(CommonTest, FittingTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/fitting/test_dpa4_ener.py b/source/tests/consistent/fitting/test_dpa4_ener.py index 64b7918919..ee10b54343 100644 --- a/source/tests/consistent/fitting/test_dpa4_ener.py +++ b/source/tests/consistent/fitting/test_dpa4_ener.py @@ -18,7 +18,7 @@ INSTALLED_PT, INSTALLED_PT_EXPT, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( FittingTest, @@ -45,10 +45,15 @@ SeZMEnerFittingTF = None -@parameterized( - ("float64", "float32"), # precision - ([0], [16, 16]), # neuron ([0] = auto-width placeholder) +DPA4_ENER_FITTING_CURATED_CASES = ( + ("float64", [0]), + ("float64", [16, 16]), + ("float32", [0]), + ("float32", [16, 16]), ) + + +@parameterized_cases(*DPA4_ENER_FITTING_CURATED_CASES) class TestDPA4Ener(CommonTest, FittingTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/fitting/test_ener.py b/source/tests/consistent/fitting/test_ener.py index ba0f68c163..ceccb6aa6e 100644 --- a/source/tests/consistent/fitting/test_ener.py +++ b/source/tests/consistent/fitting/test_ener.py @@ -22,7 +22,7 @@ INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( FittingTest, @@ -74,14 +74,57 @@ EnerFittingStrict = None -@parameterized( - (True, False), # resnet_dt - ("float64", "float32", "bfloat16"), # precision - (True, False), # mixed_types - ((0, None), (1, None), (1, [1.0])), # (numb_fparam, default_fparam) - ((0, False), (1, False), (1, True)), # (numb_aparam, use_aparam_as_mask) - ([], [-12345.6, None]), # atom_ener +ENER_FITTING_CASE_FIELDS = ( + "resnet_dt", + "precision", + "mixed_types", + "fparam", + "aparam", + "atom_ener", ) + +ENER_FITTING_BASELINE_CASE = { + "resnet_dt": True, + "precision": "float64", + "mixed_types": True, + "fparam": (0, None), + "aparam": (0, False), + "atom_ener": [], +} + + +def ener_fitting_case(**overrides: Any) -> tuple: + case = ENER_FITTING_BASELINE_CASE | overrides + return tuple(case[field] for field in ENER_FITTING_CASE_FIELDS) + + +ENER_FITTING_CURATED_CASES = ( + ener_fitting_case(), + ener_fitting_case(resnet_dt=False), + ener_fitting_case(precision="float32"), + ener_fitting_case(precision="bfloat16"), + ener_fitting_case(mixed_types=False), + ener_fitting_case(fparam=(1, None)), + ener_fitting_case(fparam=(1, [1.0])), + ener_fitting_case(aparam=(1, False)), + ener_fitting_case(aparam=(1, True)), + ener_fitting_case(atom_ener=[-12345.6, None]), + ener_fitting_case( + resnet_dt=False, + precision="float32", + mixed_types=False, + fparam=(1, [1.0]), + aparam=(1, True), + atom_ener=[-12345.6, None], + ), +) + +ENER_FITTING_STAT_CURATED_CASES = ( + ener_fitting_case(fparam=(3, None), aparam=(3, False)), +) + + +@parameterized_cases(*ENER_FITTING_CURATED_CASES) class TestEner(CommonTest, FittingTest, unittest.TestCase): @property def data(self) -> dict: @@ -420,14 +463,7 @@ def atol(self) -> float: raise ValueError(f"Unknown precision: {precision}") -@parameterized( - (True,), # resnet_dt - ("float64",), # precision - (True,), # mixed_types - ((3, None),), # (numb_fparam, default_fparam) - ((3, False),), # (numb_aparam, use_aparam_as_mask) - ([],), # atom_ener -) +@parameterized_cases(*ENER_FITTING_STAT_CURATED_CASES) class TestEnerStat(CommonTest, FittingTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/fitting/test_polar.py b/source/tests/consistent/fitting/test_polar.py index 142cbefdc8..a9b20b1d96 100644 --- a/source/tests/consistent/fitting/test_polar.py +++ b/source/tests/consistent/fitting/test_polar.py @@ -21,7 +21,7 @@ INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( DipoleFittingTest, @@ -64,12 +64,34 @@ fitting_polar, ) +POLAR_FITTING_CASE_FIELDS = ( + "resnet_dt", + "precision", + "mixed_types", +) + +POLAR_FITTING_BASELINE_CASE = { + "resnet_dt": True, + "precision": "float64", + "mixed_types": True, +} + + +def polar_fitting_case(**overrides: Any) -> tuple: + case = POLAR_FITTING_BASELINE_CASE | overrides + return tuple(case[field] for field in POLAR_FITTING_CASE_FIELDS) -@parameterized( - (True, False), # resnet_dt - ("float64", "float32"), # precision - (True, False), # mixed_types + +POLAR_FITTING_CURATED_CASES = ( + polar_fitting_case(), + polar_fitting_case(resnet_dt=False), + polar_fitting_case(precision="float32"), + polar_fitting_case(mixed_types=False), + polar_fitting_case(resnet_dt=False, precision="float32", mixed_types=False), ) + + +@parameterized_cases(*POLAR_FITTING_CURATED_CASES) class TestPolar(CommonTest, DipoleFittingTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/fitting/test_property.py b/source/tests/consistent/fitting/test_property.py index a9da348410..792808e4a3 100644 --- a/source/tests/consistent/fitting/test_property.py +++ b/source/tests/consistent/fitting/test_property.py @@ -25,7 +25,7 @@ INSTALLED_PT, INSTALLED_PT_EXPT, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( FittingTest, @@ -64,15 +64,54 @@ PropertyFittingTF = object -@parameterized( - (True, False), # resnet_dt - ("float64", "float32"), # precision - (True, False), # mixed_types - (0, 1), # numb_fparam - (0, 1), # numb_aparam - (1, 3), # task_dim - (True, False), # intensive +PROPERTY_FITTING_CASE_FIELDS = ( + "resnet_dt", + "precision", + "mixed_types", + "numb_fparam", + "numb_aparam", + "task_dim", + "intensive", ) + +PROPERTY_FITTING_BASELINE_CASE = { + "resnet_dt": True, + "precision": "float64", + "mixed_types": True, + "numb_fparam": 0, + "numb_aparam": 0, + "task_dim": 1, + "intensive": True, +} + + +def property_fitting_case(**overrides: Any) -> tuple: + case = PROPERTY_FITTING_BASELINE_CASE | overrides + return tuple(case[field] for field in PROPERTY_FITTING_CASE_FIELDS) + + +PROPERTY_FITTING_CURATED_CASES = ( + property_fitting_case(), + property_fitting_case(resnet_dt=False), + property_fitting_case(precision="float32"), + property_fitting_case(mixed_types=False), + property_fitting_case(numb_fparam=1), + property_fitting_case(numb_aparam=1), + property_fitting_case(task_dim=3), + property_fitting_case(intensive=False), + property_fitting_case( + resnet_dt=False, + precision="float32", + mixed_types=False, + numb_fparam=1, + numb_aparam=1, + task_dim=3, + intensive=False, + ), +) + + +@parameterized_cases(*PROPERTY_FITTING_CURATED_CASES) class TestProperty(CommonTest, FittingTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/loss/test_dos.py b/source/tests/consistent/loss/test_dos.py index 8ed91873ec..3974c23a0f 100644 --- a/source/tests/consistent/loss/test_dos.py +++ b/source/tests/consistent/loss/test_dos.py @@ -20,7 +20,7 @@ INSTALLED_PT, INSTALLED_PT_EXPT, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( LossTest, @@ -44,10 +44,14 @@ import array_api_strict -@parameterized( - (1.0, 0.0), # pref_dos - (1.0, 0.0), # pref_ados +DOS_LOSS_CURATED_CASES = ( + (1.0, 1.0), + (1.0, 0.0), + (0.0, 1.0), ) + + +@parameterized_cases(*DOS_LOSS_CURATED_CASES) class TestDOS(CommonTest, LossTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/loss/test_ener.py b/source/tests/consistent/loss/test_ener.py index c87e7409cb..8e95cc0590 100644 --- a/source/tests/consistent/loss/test_ener.py +++ b/source/tests/consistent/loss/test_ener.py @@ -22,7 +22,7 @@ INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( LossTest, @@ -61,14 +61,50 @@ import array_api_strict -@parameterized( - (False, True), # huber - (False, True), # enable_atom_ener_coeff - ("mse", "mae"), # loss_func - (False, True), # f_use_norm - (False, True), # mae (dp test extra MAE metrics) - (False, True), # intensive_ener_virial +ENER_LOSS_CASE_FIELDS = ( + "use_huber", + "enable_atom_ener_coeff", + "loss_func", + "f_use_norm", + "mae", + "intensive_ener_virial", ) + +ENER_LOSS_BASELINE_CASE = { + "use_huber": False, + "enable_atom_ener_coeff": False, + "loss_func": "mse", + "f_use_norm": False, + "mae": False, + "intensive_ener_virial": False, +} + + +def ener_loss_case(**overrides: Any) -> tuple: + case = ENER_LOSS_BASELINE_CASE | overrides + return tuple(case[field] for field in ENER_LOSS_CASE_FIELDS) + + +ENER_LOSS_CURATED_CASES = ( + ener_loss_case(), + ener_loss_case(use_huber=True), + ener_loss_case(enable_atom_ener_coeff=True), + ener_loss_case(loss_func="mae"), + ener_loss_case(f_use_norm=True), + ener_loss_case(mae=True), + ener_loss_case(intensive_ener_virial=True), + ener_loss_case( + use_huber=True, + enable_atom_ener_coeff=True, + loss_func="mae", + f_use_norm=True, + mae=True, + intensive_ener_virial=True, + ), +) + + +@parameterized_cases(*ENER_LOSS_CURATED_CASES) class TestEner(CommonTest, LossTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/loss/test_ener_spin.py b/source/tests/consistent/loss/test_ener_spin.py index bd6561bb78..e3ab9a8f13 100644 --- a/source/tests/consistent/loss/test_ener_spin.py +++ b/source/tests/consistent/loss/test_ener_spin.py @@ -20,7 +20,7 @@ INSTALLED_PT, INSTALLED_PT_EXPT, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( LossTest, @@ -44,11 +44,16 @@ import array_api_strict -@parameterized( - ("mse", "mae"), # loss_func - (False, True), # mae (dp test extra MAE metrics) - (False, True), # intensive_ener_virial +ENER_SPIN_LOSS_CURATED_CASES = ( + ("mse", False, False), + ("mae", False, False), + ("mse", True, False), + ("mse", False, True), + ("mae", False, True), ) + + +@parameterized_cases(*ENER_SPIN_LOSS_CURATED_CASES) class TestEnerSpin(CommonTest, LossTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/loss/test_property.py b/source/tests/consistent/loss/test_property.py index 7750eb6dae..6edae664c8 100644 --- a/source/tests/consistent/loss/test_property.py +++ b/source/tests/consistent/loss/test_property.py @@ -20,7 +20,7 @@ INSTALLED_PT, INSTALLED_PT_EXPT, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( LossTest, @@ -44,9 +44,16 @@ import array_api_strict -@parameterized( - ("smooth_mae", "mae", "mse", "rmse", "mape"), # loss_func +PROPERTY_LOSS_CURATED_CASES = ( + ("smooth_mae",), + ("mae",), + ("mse",), + ("rmse",), + ("mape",), ) + + +@parameterized_cases(*PROPERTY_LOSS_CURATED_CASES) class TestProperty(CommonTest, LossTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/loss/test_tensor.py b/source/tests/consistent/loss/test_tensor.py index 06feb908fa..ba23063e15 100644 --- a/source/tests/consistent/loss/test_tensor.py +++ b/source/tests/consistent/loss/test_tensor.py @@ -20,7 +20,7 @@ INSTALLED_PT, INSTALLED_PT_EXPT, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( LossTest, @@ -44,10 +44,14 @@ import array_api_strict -@parameterized( - (1.0, 0.0), # pref - (1.0, 0.0), # pref_atomic +TENSOR_LOSS_CURATED_CASES = ( + (1.0, 1.0), + (1.0, 0.0), + (0.0, 1.0), ) + + +@parameterized_cases(*TENSOR_LOSS_CURATED_CASES) class TestTensor(CommonTest, LossTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/model/test_dipole.py b/source/tests/consistent/model/test_dipole.py index 8b9c24cd57..30507caf13 100644 --- a/source/tests/consistent/model/test_dipole.py +++ b/source/tests/consistent/model/test_dipole.py @@ -28,7 +28,7 @@ INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( ModelTest, @@ -1224,10 +1224,15 @@ def test_get_observed_type_list(self) -> None: self.assertEqual(dp_observed, ["O"]) -@parameterized( - (([], []), ([[0, 1]], [1])), # (pair_exclude_types, atom_exclude_types) - (False, True), # fparam_in_data +MODEL_STAT_CURATED_CASES = ( + (([], []), False), + (([], []), True), + (([[0, 1]], [1]), False), + (([[0, 1]], [1]), True), ) + + +@parameterized_cases(*MODEL_STAT_CURATED_CASES) @unittest.skipUnless(INSTALLED_PT and INSTALLED_PT_EXPT, "PT and PT_EXPT are required") class TestDipoleComputeOrLoadStat(unittest.TestCase): """Test that compute_or_load_stat produces identical statistics on dp, pt, and pt_expt. diff --git a/source/tests/consistent/model/test_dos.py b/source/tests/consistent/model/test_dos.py index 016b4ffc04..bfe19dda53 100644 --- a/source/tests/consistent/model/test_dos.py +++ b/source/tests/consistent/model/test_dos.py @@ -28,7 +28,7 @@ INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( ModelTest, @@ -1213,10 +1213,15 @@ def test_get_observed_type_list(self) -> None: self.assertEqual(dp_observed, ["O"]) -@parameterized( - (([], []), ([[0, 1]], [1])), # (pair_exclude_types, atom_exclude_types) - (False, True), # fparam_in_data +MODEL_STAT_CURATED_CASES = ( + (([], []), False), + (([], []), True), + (([[0, 1]], [1]), False), + (([[0, 1]], [1]), True), ) + + +@parameterized_cases(*MODEL_STAT_CURATED_CASES) @unittest.skipUnless(INSTALLED_PT and INSTALLED_PT_EXPT, "PT and PT_EXPT are required") class TestDOSComputeOrLoadStat(unittest.TestCase): """Test that compute_or_load_stat produces identical statistics on dp, pt, and pt_expt. diff --git a/source/tests/consistent/model/test_dpa1.py b/source/tests/consistent/model/test_dpa1.py index b32570d024..e367f988b1 100644 --- a/source/tests/consistent/model/test_dpa1.py +++ b/source/tests/consistent/model/test_dpa1.py @@ -20,7 +20,7 @@ INSTALLED_TF, SKIP_FLAG, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( ModelTest, @@ -55,11 +55,13 @@ EnergyModelJAX = None -@parameterized( - ("strip", "concat"), # tebd_input_mode - # strip + smooth is inconsistent - (False,), # smooth +DPA1_ENER_CURATED_CASES = ( + ("strip", False), + ("concat", False), ) + + +@parameterized_cases(*DPA1_ENER_CURATED_CASES) class TestDPA1Ener(CommonTest, ModelTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index d62f84bea8..137c3ffc74 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -30,7 +30,7 @@ INSTALLED_TF, SKIP_FLAG, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( ModelTest, @@ -74,16 +74,28 @@ EnergyModelJAX = None -@parameterized( - ( - [], - [[0, 1]], - ), - ( - [], - [1], - ), +MODEL_EXCLUSION_CURATED_CASES = ( + ([], []), + ([], [1]), + ([[0, 1]], []), + ([[0, 1]], [1]), ) + +MODEL_STAT_CURATED_CASES = ( + (([], []), False), + (([], []), True), + (([[0, 1]], [1]), False), + (([[0, 1]], [1]), True), +) + +ENER_CHG_SPIN_EBD_FPARAM_CURATED_CASES = ( + ("no_chg_spin",), + ("explicit_chg_spin",), + ("default_chg_spin",), +) + + +@parameterized_cases(*MODEL_EXCLUSION_CURATED_CASES) class TestEner(CommonTest, ModelTest, unittest.TestCase): @property def data(self) -> dict: @@ -302,16 +314,7 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: raise ValueError(f"Unknown backend: {backend}") -@parameterized( - ( - [], - [[0, 1]], - ), - ( - [], - [1], - ), -) +@parameterized_cases(*MODEL_EXCLUSION_CURATED_CASES) class TestEnerLower(CommonTest, ModelTest, unittest.TestCase): @property def data(self) -> dict: @@ -1669,10 +1672,7 @@ def test_get_observed_type_list(self) -> None: self.assertEqual(dp_observed, ["O"]) -@parameterized( - (([], []), ([[0, 1]], [1])), # (pair_exclude_types, atom_exclude_types) - (False, True), # fparam_in_data -) +@parameterized_cases(*MODEL_STAT_CURATED_CASES) @unittest.skipUnless(INSTALLED_PT and INSTALLED_PT_EXPT, "PT and PT_EXPT are required") class TestEnerComputeOrLoadStat(unittest.TestCase): """Test that compute_or_load_stat produces identical statistics on dp, pt, and pt_expt. @@ -2015,9 +2015,7 @@ def raise_error(): compare_variables_recursive(dp_ser_loaded, pe_ser_loaded) -@parameterized( - ("no_chg_spin", "explicit_chg_spin", "default_chg_spin"), # cs_mode -) +@parameterized_cases(*ENER_CHG_SPIN_EBD_FPARAM_CURATED_CASES) @unittest.skipUnless(INSTALLED_PT and INSTALLED_PT_EXPT, "PT and PT_EXPT are required") class TestEnerChgSpinEbdFparam(unittest.TestCase): """Test dp/pt/pt_expt model forward consistency for add_chg_spin_ebd with three modes. diff --git a/source/tests/consistent/model/test_frozen.py b/source/tests/consistent/model/test_frozen.py index c19b1a8cb4..4ab62c699c 100644 --- a/source/tests/consistent/model/test_frozen.py +++ b/source/tests/consistent/model/test_frozen.py @@ -18,7 +18,7 @@ INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( ModelTest, @@ -50,6 +50,12 @@ tf_model = "deeppot_for_consistent_frozen.pb" dp_model = "deeppot_for_consistent_frozen.dp" +FROZEN_MODEL_CURATED_CASES = ( + (pt_model,), + (tf_model,), + (dp_model,), +) + def setUpModule() -> None: case = get_cases()["se_e2_a"] @@ -72,7 +78,7 @@ def tearDownModule() -> None: os.remove(tmp_pb) -@parameterized((pt_model, tf_model, dp_model)) +@parameterized_cases(*FROZEN_MODEL_CURATED_CASES) class TestFrozen(CommonTest, ModelTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/model/test_linear_ener.py b/source/tests/consistent/model/test_linear_ener.py index 9583a03668..1d3708f31b 100644 --- a/source/tests/consistent/model/test_linear_ener.py +++ b/source/tests/consistent/model/test_linear_ener.py @@ -15,7 +15,7 @@ INSTALLED_PT_EXPT, SKIP_FLAG, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( ModelTest, @@ -41,17 +41,15 @@ model_args, ) - -@parameterized( - ( - [], - [[0, 1]], - ), - ( - [], - [1], - ), +LINEAR_ENER_MODEL_EXCLUSION_CURATED_CASES = ( + ([], []), + ([], [1]), + ([[0, 1]], []), + ([[0, 1]], [1]), ) + + +@parameterized_cases(*LINEAR_ENER_MODEL_EXCLUSION_CURATED_CASES) class TestLinearEner(CommonTest, ModelTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/model/test_polar.py b/source/tests/consistent/model/test_polar.py index 4fe3a2c6df..443fd62b97 100644 --- a/source/tests/consistent/model/test_polar.py +++ b/source/tests/consistent/model/test_polar.py @@ -28,7 +28,7 @@ INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( ModelTest, @@ -1218,10 +1218,15 @@ def test_get_observed_type_list(self) -> None: self.assertEqual(dp_observed, ["O"]) -@parameterized( - (([], []), ([[0, 1]], [1])), # (pair_exclude_types, atom_exclude_types) - (False, True), # fparam_in_data +MODEL_STAT_CURATED_CASES = ( + (([], []), False), + (([], []), True), + (([[0, 1]], [1]), False), + (([[0, 1]], [1]), True), ) + + +@parameterized_cases(*MODEL_STAT_CURATED_CASES) @unittest.skipUnless(INSTALLED_PT and INSTALLED_PT_EXPT, "PT and PT_EXPT are required") class TestPolarComputeOrLoadStat(unittest.TestCase): """Test that compute_or_load_stat produces identical statistics on dp, pt, and pt_expt. diff --git a/source/tests/consistent/model/test_property.py b/source/tests/consistent/model/test_property.py index ea68c1e838..4d9f4ab3cd 100644 --- a/source/tests/consistent/model/test_property.py +++ b/source/tests/consistent/model/test_property.py @@ -27,7 +27,7 @@ INSTALLED_PT, INSTALLED_PT_EXPT, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( ModelTest, @@ -1210,10 +1210,15 @@ def test_get_observed_type_list(self) -> None: self.assertEqual(dp_observed, ["O"]) -@parameterized( - (([], []), ([[0, 1]], [1])), # (pair_exclude_types, atom_exclude_types) - (False, True), # fparam_in_data +MODEL_STAT_CURATED_CASES = ( + (([], []), False), + (([], []), True), + (([[0, 1]], [1]), False), + (([[0, 1]], [1]), True), ) + + +@parameterized_cases(*MODEL_STAT_CURATED_CASES) @unittest.skipUnless(INSTALLED_PT and INSTALLED_PT_EXPT, "PT and PT_EXPT are required") class TestPropertyComputeOrLoadStat(unittest.TestCase): """Test that compute_or_load_stat produces identical statistics on dp, pt, and pt_expt. diff --git a/source/tests/consistent/model/test_zbl_ener.py b/source/tests/consistent/model/test_zbl_ener.py index 0cf907bee2..ba4e687b16 100644 --- a/source/tests/consistent/model/test_zbl_ener.py +++ b/source/tests/consistent/model/test_zbl_ener.py @@ -30,7 +30,7 @@ INSTALLED_PT_EXPT, SKIP_FLAG, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( ModelTest, @@ -61,16 +61,20 @@ TESTS_DIR = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) -@parameterized( - ( - [], - [[0, 1]], - ), - ( - [], - [1], - ), +ZBL_MODEL_EXCLUSION_CURATED_CASES = ( + ([], []), + ([], [1]), + ([[0, 1]], []), + ([[0, 1]], [1]), ) + +ZBL_MODEL_STAT_CURATED_CASES = ( + (([], []),), + (([[0, 1]], [1]),), +) + + +@parameterized_cases(*ZBL_MODEL_EXCLUSION_CURATED_CASES) class TestEner(CommonTest, ModelTest, unittest.TestCase): @property def data(self) -> dict: @@ -1018,9 +1022,7 @@ def test_get_observed_type_list(self) -> None: self.assertEqual(dp_observed, ["O"]) -@parameterized( - (([], []), ([[0, 1]], [1])), # (pair_exclude_types, atom_exclude_types) -) +@parameterized_cases(*ZBL_MODEL_STAT_CURATED_CASES) @unittest.skipUnless(INSTALLED_PT and INSTALLED_PT_EXPT, "PT and PT_EXPT are required") class TestZBLComputeOrLoadStat(unittest.TestCase): """Test that compute_or_load_stat produces identical statistics on dp, pt, and pt_expt. diff --git a/source/tests/consistent/test_activation.py b/source/tests/consistent/test_activation.py index 803336c15c..23fd8eff52 100644 --- a/source/tests/consistent/test_activation.py +++ b/source/tests/consistent/test_activation.py @@ -21,7 +21,7 @@ INSTALLED_PT, INSTALLED_PT_EXPT, INSTALLED_TF, - parameterized, + parameterized_cases, ) if INSTALLED_PT: @@ -54,9 +54,10 @@ ) -@parameterized( - tuple([x.capitalize() for x in VALID_ACTIVATION]), -) +ACTIVATION_FUNCTION_CURATED_CASES = tuple((x.capitalize(),) for x in VALID_ACTIVATION) + + +@parameterized_cases(*ACTIVATION_FUNCTION_CURATED_CASES) class TestActivationFunctionConsistent(unittest.TestCase): def setUp(self) -> None: (self.activation,) = self.param @@ -117,14 +118,15 @@ def test_pt_expt_consistent_with_ref(self) -> None: np.testing.assert_allclose(self.ref, test, atol=1e-10) -@parameterized( - ( - "silut", # default threshold 3.0 - "silut:3.0", # explicit threshold 3.0 - "silut:10.0", # large threshold - "custom_silu:5.0", # alias - ), +SILUT_VARIANT_CURATED_CASES = ( + ("silut",), # default threshold 3.0 + ("silut:3.0",), # explicit threshold 3.0 + ("silut:10.0",), # large threshold + ("custom_silu:5.0",), # alias ) + + +@parameterized_cases(*SILUT_VARIANT_CURATED_CASES) class TestSilutVariantsConsistent(unittest.TestCase): """Cross-backend consistency for silut with different thresholds.""" diff --git a/source/tests/consistent/test_learning_rate.py b/source/tests/consistent/test_learning_rate.py index cd862c24ef..e7c5f77bd8 100644 --- a/source/tests/consistent/test_learning_rate.py +++ b/source/tests/consistent/test_learning_rate.py @@ -18,7 +18,7 @@ INSTALLED_ARRAY_API_STRICT, INSTALLED_JAX, INSTALLED_PT, - parameterized, + parameterized_cases, ) if INSTALLED_PT: @@ -34,7 +34,7 @@ import array_api_strict as xp -@parameterized( +LEARNING_RATE_CURATED_CASES = ( ( { "type": "exp", @@ -44,6 +44,8 @@ "num_steps": 1000000, "warmup_steps": 10000, }, + ), + ( { "type": "cosine", "start_lr": 1e-3, @@ -51,6 +53,8 @@ "num_steps": 1000000, "warmup_steps": 10000, }, + ), + ( { "type": "wsd", "start_lr": 1e-3, @@ -59,6 +63,8 @@ "warmup_steps": 10000, "decay_phase_ratio": 0.1, }, + ), + ( { "type": "wsd", "start_lr": 1e-3, @@ -68,6 +74,8 @@ "decay_phase_ratio": 0.1, "decay_type": "cosine", }, + ), + ( { "type": "wsd", "start_lr": 1e-3, @@ -79,6 +87,9 @@ }, ), ) + + +@parameterized_cases(*LEARNING_RATE_CURATED_CASES) class TestLearningRateConsistent(unittest.TestCase): """Test learning rate consistency across different array backends.""" diff --git a/source/tests/consistent/test_type_embedding.py b/source/tests/consistent/test_type_embedding.py index 9c1de0e8c5..44c2974964 100644 --- a/source/tests/consistent/test_type_embedding.py +++ b/source/tests/consistent/test_type_embedding.py @@ -21,7 +21,7 @@ INSTALLED_PT, INSTALLED_TF, CommonTest, - parameterized, + parameterized_cases, ) if INSTALLED_PT: @@ -55,13 +55,46 @@ TypeEmbedNetPD = object -@parameterized( - (True, False), # resnet_dt - ("float32", "float64"), # precision - (True, False), # padding - (True, False), # use_econf_tebd - (True, False), # use_tebd_bias +TYPE_EMBEDDING_CASE_FIELDS = ( + "resnet_dt", + "precision", + "padding", + "use_econf_tebd", + "use_tebd_bias", ) + +TYPE_EMBEDDING_BASELINE_CASE = { + "resnet_dt": True, + "precision": "float64", + "padding": True, + "use_econf_tebd": True, + "use_tebd_bias": True, +} + + +def type_embedding_case(**overrides: Any) -> tuple: + case = TYPE_EMBEDDING_BASELINE_CASE | overrides + return tuple(case[field] for field in TYPE_EMBEDDING_CASE_FIELDS) + + +TYPE_EMBEDDING_CURATED_CASES = ( + type_embedding_case(), + type_embedding_case(resnet_dt=False), + type_embedding_case(precision="float32"), + type_embedding_case(padding=False), + type_embedding_case(use_econf_tebd=False), + type_embedding_case(use_tebd_bias=False), + type_embedding_case( + resnet_dt=False, + precision="float32", + padding=False, + use_econf_tebd=False, + use_tebd_bias=False, + ), +) + + +@parameterized_cases(*TYPE_EMBEDDING_CURATED_CASES) class TestTypeEmbedding(CommonTest, unittest.TestCase): """Useful utilities for descriptor tests.""" From ccdeedbfc7ad7185231de683206b0a53de887ffb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Jun 2026 05:48:57 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/consistent/model/test_polar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/consistent/model/test_polar.py b/source/tests/consistent/model/test_polar.py index d949f5a862..811ef7ad41 100644 --- a/source/tests/consistent/model/test_polar.py +++ b/source/tests/consistent/model/test_polar.py @@ -1525,10 +1525,10 @@ def test_compute_stat(self) -> None: ) def test_load_stat_from_file(self) -> None: + import tempfile from copy import ( deepcopy, ) - import tempfile from pathlib import ( Path, ) From ff4f78865dc918468aa9aed4fdd5e6fb5f3d4dc6 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 30 Jun 2026 16:44:01 +0800 Subject: [PATCH 3/3] test(consistent): skip unsupported TF DPA1 smooth case --- source/tests/consistent/model/test_dpa1.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/source/tests/consistent/model/test_dpa1.py b/source/tests/consistent/model/test_dpa1.py index 3c284ba90c..fa6493e533 100644 --- a/source/tests/consistent/model/test_dpa1.py +++ b/source/tests/consistent/model/test_dpa1.py @@ -128,6 +128,11 @@ def get_reference_backend(self): return self.RefBackend.DP raise ValueError("No available reference") + @property + def skip_tf(self) -> bool: + _, smooth_type_embedding = self.param + return CommonTest.skip_tf or smooth_type_embedding + @property def skip_jax(self) -> bool: return not INSTALLED_JAX