diff --git a/source/tests/consistent/descriptor/test_se_atten_v2.py b/source/tests/consistent/descriptor/test_se_atten_v2.py index a7d0fa8e52..77a9d2fe5e 100644 --- a/source/tests/consistent/descriptor/test_se_atten_v2.py +++ b/source/tests/consistent/descriptor/test_se_atten_v2.py @@ -22,7 +22,7 @@ INSTALLED_PT_EXPT, INSTALLED_TF2, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( DescriptorAPITest, @@ -70,26 +70,77 @@ descrpt_se_atten_args, ) +SE_ATTEN_V2_CASE_FIELDS = ( + "tebd_dim", + "resnet_dt", + "type_one_side", + "attn", + "attn_layer", + "attn_dotr", + "excluded_types", + "env_protection", + "set_davg_zero", + "scaling_factor", + "normalize", + "temperature", + "ln_eps", + "concat_output_tebd", + "precision", + "use_econf_tebd", + "use_tebd_bias", +) -@parameterized( - (4,), # tebd_dim - (True,), # resnet_dt - (True, False), # type_one_side - (20,), # attn - (0, 2), # attn_layer - (True, False), # attn_dotr - ([], [[0, 1]]), # excluded_types - (0.0,), # env_protection - (True, False), # set_davg_zero - (1.0,), # scaling_factor - (True, False), # normalize - (None, 1.0), # temperature - (1e-5,), # ln_eps - (True,), # concat_output_tebd - ("float64",), # precision - (True, False), # use_econf_tebd - (False,), # use_tebd_bias +SE_ATTEN_V2_BASELINE_CASE = { + "tebd_dim": 4, + "resnet_dt": True, + "type_one_side": True, + "attn": 20, + "attn_layer": 2, + "attn_dotr": True, + "excluded_types": [], + "env_protection": 0.0, + "set_davg_zero": True, + "scaling_factor": 1.0, + "normalize": True, + "temperature": None, + "ln_eps": 1e-5, + "concat_output_tebd": True, + "precision": "float64", + "use_econf_tebd": True, + "use_tebd_bias": False, +} + + +def se_atten_v2_case(**overrides: Any) -> tuple: + case = SE_ATTEN_V2_BASELINE_CASE | overrides + return tuple(case[field] for field in SE_ATTEN_V2_CASE_FIELDS) + + +SE_ATTEN_V2_CURATED_CASES = ( + se_atten_v2_case(), + se_atten_v2_case(type_one_side=False), + se_atten_v2_case(attn_layer=0, attn_dotr=False, normalize=False), + se_atten_v2_case(attn_dotr=False), + se_atten_v2_case(excluded_types=[[0, 1]]), + se_atten_v2_case(set_davg_zero=False), + se_atten_v2_case(normalize=False), + se_atten_v2_case(temperature=1.0), + se_atten_v2_case(use_econf_tebd=False), + se_atten_v2_case( + type_one_side=False, + attn_dotr=False, + excluded_types=[[0, 1]], + set_davg_zero=False, + normalize=False, + temperature=1.0, + use_econf_tebd=False, + ), ) + +SE_ATTEN_V2_DESCRIPTOR_API_CURATED_CASES = SE_ATTEN_V2_CURATED_CASES + + +@parameterized_cases(*SE_ATTEN_V2_CURATED_CASES) class TestSeAttenV2(CommonTest, DescriptorTest, unittest.TestCase): @property def data(self) -> dict: @@ -543,25 +594,7 @@ def atol(self) -> float: raise ValueError(f"Unknown precision: {precision}") -@parameterized( - (4,), # tebd_dim - (True,), # resnet_dt - (True, False), # type_one_side - (20,), # attn - (0, 2), # attn_layer - (True, False), # attn_dotr - ([], [[0, 1]]), # excluded_types - (0.0,), # env_protection - (True, False), # set_davg_zero - (1.0,), # scaling_factor - (True, False), # normalize - (None, 1.0), # temperature - (1e-5,), # ln_eps - (True,), # concat_output_tebd - ("float64",), # precision - (True, False), # use_econf_tebd - (False,), # use_tebd_bias -) +@parameterized_cases(*SE_ATTEN_V2_DESCRIPTOR_API_CURATED_CASES) class TestSeAttenV2DescriptorAPI(DescriptorAPITest, unittest.TestCase): """Test consistency of BaseDescriptor API methods across backends.""" diff --git a/source/tests/consistent/descriptor/test_se_r.py b/source/tests/consistent/descriptor/test_se_r.py index 4ceaec57d1..0f81395680 100644 --- a/source/tests/consistent/descriptor/test_se_r.py +++ b/source/tests/consistent/descriptor/test_se_r.py @@ -19,7 +19,7 @@ INSTALLED_TF, INSTALLED_TF2, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( DescriptorAPITest, @@ -58,12 +58,43 @@ DescrptSeRArrayAPIStrict = None -@parameterized( - (True, False), # resnet_dt - (True, False), # type_one_side - ([], [[0, 1]]), # excluded_types - ("float32", "float64"), # precision +SE_R_CASE_FIELDS = ( + "resnet_dt", + "type_one_side", + "excluded_types", + "precision", ) + +SE_R_BASELINE_CASE = { + "resnet_dt": True, + "type_one_side": True, + "excluded_types": [], + "precision": "float64", +} + + +def se_r_case(**overrides: Any) -> tuple: + case = SE_R_BASELINE_CASE | overrides + return tuple(case[field] for field in SE_R_CASE_FIELDS) + + +SE_R_CURATED_CASES = ( + se_r_case(), + se_r_case(resnet_dt=False), + se_r_case(type_one_side=False), + se_r_case(excluded_types=[[0, 1]]), + se_r_case(precision="float32"), +) + +SE_R_DESCRIPTOR_API_CURATED_CASES = ( + se_r_case(), + se_r_case(resnet_dt=False), + se_r_case(type_one_side=False), + se_r_case(excluded_types=[[0, 1]]), +) + + +@parameterized_cases(*SE_R_CURATED_CASES) class TestSeR(CommonTest, DescriptorTest, unittest.TestCase): @property def data(self) -> dict: @@ -289,12 +320,7 @@ def atol(self) -> float: raise ValueError(f"Unknown precision: {precision}") -@parameterized( - (True, False), # resnet_dt - (True, False), # type_one_side - ([], [[0, 1]]), # excluded_types - ("float64",), # precision -) +@parameterized_cases(*SE_R_DESCRIPTOR_API_CURATED_CASES) class TestSeRDescriptorAPI(DescriptorAPITest, unittest.TestCase): """Test consistency of BaseDescriptor API methods across backends.""" diff --git a/source/tests/consistent/descriptor/test_se_t.py b/source/tests/consistent/descriptor/test_se_t.py index 5980c04e87..3eebba009f 100644 --- a/source/tests/consistent/descriptor/test_se_t.py +++ b/source/tests/consistent/descriptor/test_se_t.py @@ -19,7 +19,7 @@ INSTALLED_TF, INSTALLED_TF2, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( DescriptorAPITest, @@ -54,13 +54,45 @@ descrpt_se_t_args, ) +SE_T_CASE_FIELDS = ( + "resnet_dt", + "excluded_types", + "precision", + "env_protection", +) + +SE_T_BASELINE_CASE = { + "resnet_dt": True, + "excluded_types": [], + "precision": "float64", + "env_protection": 0.0, +} + + +def se_t_case(**overrides: Any) -> tuple: + case = SE_T_BASELINE_CASE | overrides + return tuple(case[field] for field in SE_T_CASE_FIELDS) + -@parameterized( - (True, False), # resnet_dt - ([], [[0, 1]]), # excluded_types - ("float32", "float64"), # precision - (0.0, 1e-8, 1e-2), # env_protection +SE_T_CURATED_CASES = ( + se_t_case(), + se_t_case(resnet_dt=False), + se_t_case(excluded_types=[[0, 1]]), + se_t_case(precision="float32"), + se_t_case(env_protection=1e-8), + se_t_case(env_protection=1e-2), ) + +SE_T_DESCRIPTOR_API_CURATED_CASES = ( + se_t_case(), + se_t_case(resnet_dt=False), + se_t_case(excluded_types=[[0, 1]]), + se_t_case(env_protection=1e-8), + se_t_case(env_protection=1e-2), +) + + +@parameterized_cases(*SE_T_CURATED_CASES) class TestSeT(CommonTest, DescriptorTest, unittest.TestCase): @property def data(self) -> dict: @@ -280,12 +312,7 @@ def atol(self) -> float: raise ValueError(f"Unknown precision: {precision}") -@parameterized( - (True, False), # resnet_dt - ([], [[0, 1]]), # excluded_types - ("float64",), # precision - (0.0, 1e-8, 1e-2), # env_protection -) +@parameterized_cases(*SE_T_DESCRIPTOR_API_CURATED_CASES) class TestSeTDescriptorAPI(DescriptorAPITest, unittest.TestCase): """Test consistency of BaseDescriptor API methods across backends.""" diff --git a/source/tests/consistent/descriptor/test_se_t_tebd.py b/source/tests/consistent/descriptor/test_se_t_tebd.py index b9d10d53a9..42b3e81bea 100644 --- a/source/tests/consistent/descriptor/test_se_t_tebd.py +++ b/source/tests/consistent/descriptor/test_se_t_tebd.py @@ -22,7 +22,7 @@ INSTALLED_PT_EXPT, INSTALLED_TF2, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( DescriptorAPITest, @@ -62,20 +62,60 @@ descrpt_se_e3_tebd_args, ) +SE_T_TEBD_CASE_FIELDS = ( + "tebd_dim", + "tebd_input_mode", + "resnet_dt", + "excluded_types", + "env_protection", + "set_davg_zero", + "smooth", + "concat_output_tebd", + "precision", + "use_econf_tebd", + "use_tebd_bias", +) + +SE_T_TEBD_BASELINE_CASE = { + "tebd_dim": 4, + "tebd_input_mode": "strip", + "resnet_dt": True, + "excluded_types": [], + "env_protection": 0.0, + "set_davg_zero": True, + "smooth": True, + "concat_output_tebd": True, + "precision": "float64", + "use_econf_tebd": True, + "use_tebd_bias": False, +} + + +def se_t_tebd_case(**overrides: Any) -> tuple: + case = SE_T_TEBD_BASELINE_CASE | overrides + return tuple(case[field] for field in SE_T_TEBD_CASE_FIELDS) + -@parameterized( - (4,), # tebd_dim - ("strip",), # tebd_input_mode - (True,), # resnet_dt - ([], [[0, 1]]), # excluded_types - (0.0,), # env_protection - (True, False), # set_davg_zero - (True, False), # smooth - (True,), # concat_output_tebd - ("float64",), # precision - (True, False), # use_econf_tebd - (False, True), # use_tebd_bias +SE_T_TEBD_CURATED_CASES = ( + se_t_tebd_case(), + se_t_tebd_case(excluded_types=[[0, 1]]), + se_t_tebd_case(set_davg_zero=False), + se_t_tebd_case(smooth=False), + se_t_tebd_case(use_econf_tebd=False), + se_t_tebd_case(use_tebd_bias=True), + se_t_tebd_case( + excluded_types=[[0, 1]], + set_davg_zero=False, + smooth=False, + use_econf_tebd=False, + use_tebd_bias=True, + ), ) + +SE_T_TEBD_DESCRIPTOR_API_CURATED_CASES = SE_T_TEBD_CURATED_CASES + + +@parameterized_cases(*SE_T_TEBD_CURATED_CASES) class TestSeTTebd(CommonTest, DescriptorTest, unittest.TestCase): @property def data(self) -> dict: @@ -374,19 +414,7 @@ def atol(self) -> float: raise ValueError(f"Unknown precision: {precision}") -@parameterized( - (4,), # tebd_dim - ("strip",), # tebd_input_mode - (True,), # resnet_dt - ([], [[0, 1]]), # excluded_types - (0.0,), # env_protection - (True, False), # set_davg_zero - (True, False), # smooth - (True,), # concat_output_tebd - ("float64",), # precision - (True, False), # use_econf_tebd - (False, True), # use_tebd_bias -) +@parameterized_cases(*SE_T_TEBD_DESCRIPTOR_API_CURATED_CASES) class TestSeTTebdDescriptorAPI(DescriptorAPITest, unittest.TestCase): """Test consistency of BaseDescriptor API methods across backends.""" diff --git a/source/tests/consistent/fitting/test_dipole.py b/source/tests/consistent/fitting/test_dipole.py index 4d052cac73..b470e1699d 100644 --- a/source/tests/consistent/fitting/test_dipole.py +++ b/source/tests/consistent/fitting/test_dipole.py @@ -22,7 +22,7 @@ INSTALLED_TF, INSTALLED_TF2, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( DipoleFittingTest, @@ -72,13 +72,42 @@ fitting_dipole, ) +DIPOLE_FITTING_CASE_FIELDS = ( + "resnet_dt", + "precision", + "mixed_types", + "sel_type", +) + +DIPOLE_FITTING_BASELINE_CASE = { + "resnet_dt": True, + "precision": "float64", + "mixed_types": True, + "sel_type": None, +} + + +def dipole_fitting_case(**overrides: Any) -> tuple: + case = DIPOLE_FITTING_BASELINE_CASE | overrides + return tuple(case[field] for field in DIPOLE_FITTING_CASE_FIELDS) -@parameterized( - (True, False), # resnet_dt - ("float64", "float32"), # precision - (True, False), # mixed_types - (None, [0]), # sel_type + +DIPOLE_FITTING_CURATED_CASES = ( + dipole_fitting_case(), + dipole_fitting_case(resnet_dt=False), + dipole_fitting_case(precision="float32"), + dipole_fitting_case(mixed_types=False), + dipole_fitting_case(sel_type=[0]), + dipole_fitting_case( + resnet_dt=False, + precision="float32", + mixed_types=False, + sel_type=[0], + ), ) + + +@parameterized_cases(*DIPOLE_FITTING_CURATED_CASES) class TestDipole(CommonTest, DipoleFittingTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/fitting/test_dos.py b/source/tests/consistent/fitting/test_dos.py index a1f8530c5d..dc27dca9d7 100644 --- a/source/tests/consistent/fitting/test_dos.py +++ b/source/tests/consistent/fitting/test_dos.py @@ -22,7 +22,7 @@ INSTALLED_TF, INSTALLED_TF2, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( FittingTest, @@ -70,14 +70,50 @@ DOSFittingStrict = object -@parameterized( - (True, False), # resnet_dt - ("float64", "float32"), # precision - (True, False), # mixed_types - (0, 1), # numb_fparam - (0, 1), # numb_aparam - (10, 20), # numb_dos +DOS_FITTING_CASE_FIELDS = ( + "resnet_dt", + "precision", + "mixed_types", + "numb_fparam", + "numb_aparam", + "numb_dos", ) + +DOS_FITTING_BASELINE_CASE = { + "resnet_dt": True, + "precision": "float64", + "mixed_types": True, + "numb_fparam": 0, + "numb_aparam": 0, + "numb_dos": 10, +} + + +def dos_fitting_case(**overrides: Any) -> tuple: + case = DOS_FITTING_BASELINE_CASE | overrides + return tuple(case[field] for field in DOS_FITTING_CASE_FIELDS) + + +DOS_FITTING_CURATED_CASES = ( + dos_fitting_case(), + dos_fitting_case(resnet_dt=False), + dos_fitting_case(precision="float32"), + dos_fitting_case(mixed_types=False), + dos_fitting_case(numb_fparam=1), + dos_fitting_case(numb_aparam=1), + dos_fitting_case(numb_dos=20), + dos_fitting_case( + resnet_dt=False, + precision="float32", + mixed_types=False, + numb_fparam=1, + numb_aparam=1, + numb_dos=20, + ), +) + + +@parameterized_cases(*DOS_FITTING_CURATED_CASES) class TestDOS(CommonTest, FittingTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/fitting/test_dpa4_ener.py b/source/tests/consistent/fitting/test_dpa4_ener.py index 64b7918919..ee10b54343 100644 --- a/source/tests/consistent/fitting/test_dpa4_ener.py +++ b/source/tests/consistent/fitting/test_dpa4_ener.py @@ -18,7 +18,7 @@ INSTALLED_PT, INSTALLED_PT_EXPT, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( FittingTest, @@ -45,10 +45,15 @@ SeZMEnerFittingTF = None -@parameterized( - ("float64", "float32"), # precision - ([0], [16, 16]), # neuron ([0] = auto-width placeholder) +DPA4_ENER_FITTING_CURATED_CASES = ( + ("float64", [0]), + ("float64", [16, 16]), + ("float32", [0]), + ("float32", [16, 16]), ) + + +@parameterized_cases(*DPA4_ENER_FITTING_CURATED_CASES) class TestDPA4Ener(CommonTest, FittingTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/fitting/test_ener.py b/source/tests/consistent/fitting/test_ener.py index 1c2caa7027..74de77cb78 100644 --- a/source/tests/consistent/fitting/test_ener.py +++ b/source/tests/consistent/fitting/test_ener.py @@ -23,7 +23,7 @@ INSTALLED_TF, INSTALLED_TF2, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( FittingTest, @@ -82,14 +82,57 @@ EnerFittingStrict = None -@parameterized( - (True, False), # resnet_dt - ("float64", "float32", "bfloat16"), # precision - (True, False), # mixed_types - ((0, None), (1, None), (1, [1.0])), # (numb_fparam, default_fparam) - ((0, False), (1, False), (1, True)), # (numb_aparam, use_aparam_as_mask) - ([], [-12345.6, None]), # atom_ener +ENER_FITTING_CASE_FIELDS = ( + "resnet_dt", + "precision", + "mixed_types", + "fparam", + "aparam", + "atom_ener", ) + +ENER_FITTING_BASELINE_CASE = { + "resnet_dt": True, + "precision": "float64", + "mixed_types": True, + "fparam": (0, None), + "aparam": (0, False), + "atom_ener": [], +} + + +def ener_fitting_case(**overrides: Any) -> tuple: + case = ENER_FITTING_BASELINE_CASE | overrides + return tuple(case[field] for field in ENER_FITTING_CASE_FIELDS) + + +ENER_FITTING_CURATED_CASES = ( + ener_fitting_case(), + ener_fitting_case(resnet_dt=False), + ener_fitting_case(precision="float32"), + ener_fitting_case(precision="bfloat16"), + ener_fitting_case(mixed_types=False), + ener_fitting_case(fparam=(1, None)), + ener_fitting_case(fparam=(1, [1.0])), + ener_fitting_case(aparam=(1, False)), + ener_fitting_case(aparam=(1, True)), + ener_fitting_case(atom_ener=[-12345.6, None]), + ener_fitting_case( + resnet_dt=False, + precision="float32", + mixed_types=False, + fparam=(1, [1.0]), + aparam=(1, True), + atom_ener=[-12345.6, None], + ), +) + +ENER_FITTING_STAT_CURATED_CASES = ( + ener_fitting_case(fparam=(3, None), aparam=(3, False)), +) + + +@parameterized_cases(*ENER_FITTING_CURATED_CASES) class TestEner(CommonTest, FittingTest, unittest.TestCase): @property def data(self) -> dict: @@ -467,14 +510,7 @@ def atol(self) -> float: raise ValueError(f"Unknown precision: {precision}") -@parameterized( - (True,), # resnet_dt - ("float64",), # precision - (True,), # mixed_types - ((3, None),), # (numb_fparam, default_fparam) - ((3, False),), # (numb_aparam, use_aparam_as_mask) - ([],), # atom_ener -) +@parameterized_cases(*ENER_FITTING_STAT_CURATED_CASES) class TestEnerStat(CommonTest, FittingTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/fitting/test_polar.py b/source/tests/consistent/fitting/test_polar.py index 80bf369848..a9676eb950 100644 --- a/source/tests/consistent/fitting/test_polar.py +++ b/source/tests/consistent/fitting/test_polar.py @@ -22,7 +22,7 @@ INSTALLED_TF, INSTALLED_TF2, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( DipoleFittingTest, @@ -72,12 +72,34 @@ fitting_polar, ) +POLAR_FITTING_CASE_FIELDS = ( + "resnet_dt", + "precision", + "mixed_types", +) + +POLAR_FITTING_BASELINE_CASE = { + "resnet_dt": True, + "precision": "float64", + "mixed_types": True, +} + + +def polar_fitting_case(**overrides: Any) -> tuple: + case = POLAR_FITTING_BASELINE_CASE | overrides + return tuple(case[field] for field in POLAR_FITTING_CASE_FIELDS) -@parameterized( - (True, False), # resnet_dt - ("float64", "float32"), # precision - (True, False), # mixed_types + +POLAR_FITTING_CURATED_CASES = ( + polar_fitting_case(), + polar_fitting_case(resnet_dt=False), + polar_fitting_case(precision="float32"), + polar_fitting_case(mixed_types=False), + polar_fitting_case(resnet_dt=False, precision="float32", mixed_types=False), ) + + +@parameterized_cases(*POLAR_FITTING_CURATED_CASES) class TestPolar(CommonTest, DipoleFittingTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/fitting/test_property.py b/source/tests/consistent/fitting/test_property.py index 02e49ede4b..853e7bb974 100644 --- a/source/tests/consistent/fitting/test_property.py +++ b/source/tests/consistent/fitting/test_property.py @@ -26,7 +26,7 @@ INSTALLED_PT_EXPT, INSTALLED_TF2, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( FittingTest, @@ -72,15 +72,54 @@ PropertyFittingTF2 = object -@parameterized( - (True, False), # resnet_dt - ("float64", "float32"), # precision - (True, False), # mixed_types - (0, 1), # numb_fparam - (0, 1), # numb_aparam - (1, 3), # task_dim - (True, False), # intensive +PROPERTY_FITTING_CASE_FIELDS = ( + "resnet_dt", + "precision", + "mixed_types", + "numb_fparam", + "numb_aparam", + "task_dim", + "intensive", ) + +PROPERTY_FITTING_BASELINE_CASE = { + "resnet_dt": True, + "precision": "float64", + "mixed_types": True, + "numb_fparam": 0, + "numb_aparam": 0, + "task_dim": 1, + "intensive": True, +} + + +def property_fitting_case(**overrides: Any) -> tuple: + case = PROPERTY_FITTING_BASELINE_CASE | overrides + return tuple(case[field] for field in PROPERTY_FITTING_CASE_FIELDS) + + +PROPERTY_FITTING_CURATED_CASES = ( + property_fitting_case(), + property_fitting_case(resnet_dt=False), + property_fitting_case(precision="float32"), + property_fitting_case(mixed_types=False), + property_fitting_case(numb_fparam=1), + property_fitting_case(numb_aparam=1), + property_fitting_case(task_dim=3), + property_fitting_case(intensive=False), + property_fitting_case( + resnet_dt=False, + precision="float32", + mixed_types=False, + numb_fparam=1, + numb_aparam=1, + task_dim=3, + intensive=False, + ), +) + + +@parameterized_cases(*PROPERTY_FITTING_CURATED_CASES) class TestProperty(CommonTest, FittingTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/loss/test_dos.py b/source/tests/consistent/loss/test_dos.py index a920f07728..4d03ee9cce 100644 --- a/source/tests/consistent/loss/test_dos.py +++ b/source/tests/consistent/loss/test_dos.py @@ -21,7 +21,7 @@ INSTALLED_PT_EXPT, INSTALLED_TF2, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( LossTest, @@ -45,10 +45,14 @@ import array_api_strict -@parameterized( - (1.0, 0.0), # pref_dos - (1.0, 0.0), # pref_ados +DOS_LOSS_CURATED_CASES = ( + (1.0, 1.0), + (1.0, 0.0), + (0.0, 1.0), ) + + +@parameterized_cases(*DOS_LOSS_CURATED_CASES) class TestDOS(CommonTest, LossTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/loss/test_ener.py b/source/tests/consistent/loss/test_ener.py index c167a9f43d..58addceb41 100644 --- a/source/tests/consistent/loss/test_ener.py +++ b/source/tests/consistent/loss/test_ener.py @@ -23,7 +23,7 @@ INSTALLED_TF, INSTALLED_TF2, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( LossTest, @@ -62,14 +62,50 @@ import array_api_strict -@parameterized( - (False, True), # huber - (False, True), # enable_atom_ener_coeff - ("mse", "mae"), # loss_func - (False, True), # f_use_norm - (False, True), # mae (dp test extra MAE metrics) - (False, True), # intensive_ener_virial +ENER_LOSS_CASE_FIELDS = ( + "use_huber", + "enable_atom_ener_coeff", + "loss_func", + "f_use_norm", + "mae", + "intensive_ener_virial", ) + +ENER_LOSS_BASELINE_CASE = { + "use_huber": False, + "enable_atom_ener_coeff": False, + "loss_func": "mse", + "f_use_norm": False, + "mae": False, + "intensive_ener_virial": False, +} + + +def ener_loss_case(**overrides: Any) -> tuple: + case = ENER_LOSS_BASELINE_CASE | overrides + return tuple(case[field] for field in ENER_LOSS_CASE_FIELDS) + + +ENER_LOSS_CURATED_CASES = ( + ener_loss_case(), + ener_loss_case(use_huber=True), + ener_loss_case(enable_atom_ener_coeff=True), + ener_loss_case(loss_func="mae"), + ener_loss_case(f_use_norm=True), + ener_loss_case(mae=True), + ener_loss_case(intensive_ener_virial=True), + ener_loss_case( + use_huber=True, + enable_atom_ener_coeff=True, + loss_func="mae", + f_use_norm=True, + mae=True, + intensive_ener_virial=True, + ), +) + + +@parameterized_cases(*ENER_LOSS_CURATED_CASES) class TestEner(CommonTest, LossTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/loss/test_ener_spin.py b/source/tests/consistent/loss/test_ener_spin.py index dacb354e68..9f8137be73 100644 --- a/source/tests/consistent/loss/test_ener_spin.py +++ b/source/tests/consistent/loss/test_ener_spin.py @@ -21,7 +21,7 @@ INSTALLED_PT_EXPT, INSTALLED_TF2, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( LossTest, @@ -45,11 +45,16 @@ import array_api_strict -@parameterized( - ("mse", "mae"), # loss_func - (False, True), # mae (dp test extra MAE metrics) - (False, True), # intensive_ener_virial +ENER_SPIN_LOSS_CURATED_CASES = ( + ("mse", False, False), + ("mae", False, False), + ("mse", True, False), + ("mse", False, True), + ("mae", False, True), ) + + +@parameterized_cases(*ENER_SPIN_LOSS_CURATED_CASES) class TestEnerSpin(CommonTest, LossTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/loss/test_property.py b/source/tests/consistent/loss/test_property.py index 1bb838cc5a..d0817934d2 100644 --- a/source/tests/consistent/loss/test_property.py +++ b/source/tests/consistent/loss/test_property.py @@ -21,7 +21,7 @@ INSTALLED_PT_EXPT, INSTALLED_TF2, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( LossTest, @@ -45,9 +45,16 @@ import array_api_strict -@parameterized( - ("smooth_mae", "mae", "mse", "rmse", "mape"), # loss_func +PROPERTY_LOSS_CURATED_CASES = ( + ("smooth_mae",), + ("mae",), + ("mse",), + ("rmse",), + ("mape",), ) + + +@parameterized_cases(*PROPERTY_LOSS_CURATED_CASES) class TestProperty(CommonTest, LossTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/loss/test_tensor.py b/source/tests/consistent/loss/test_tensor.py index 1c41fe425c..a6eb9b8d38 100644 --- a/source/tests/consistent/loss/test_tensor.py +++ b/source/tests/consistent/loss/test_tensor.py @@ -21,7 +21,7 @@ INSTALLED_PT_EXPT, INSTALLED_TF2, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( LossTest, @@ -45,10 +45,14 @@ import array_api_strict -@parameterized( - (1.0, 0.0), # pref - (1.0, 0.0), # pref_atomic +TENSOR_LOSS_CURATED_CASES = ( + (1.0, 1.0), + (1.0, 0.0), + (0.0, 1.0), ) + + +@parameterized_cases(*TENSOR_LOSS_CURATED_CASES) class TestTensor(CommonTest, LossTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/model/test_dipole.py b/source/tests/consistent/model/test_dipole.py index bfa4bb2411..68f2ede6ba 100644 --- a/source/tests/consistent/model/test_dipole.py +++ b/source/tests/consistent/model/test_dipole.py @@ -29,7 +29,7 @@ INSTALLED_TF, INSTALLED_TF2, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( ModelTest, @@ -1245,10 +1245,15 @@ def test_get_observed_type_list(self) -> None: self.assertEqual(dp_observed, ["O"]) -@parameterized( - (([], []), ([[0, 1]], [1])), # (pair_exclude_types, atom_exclude_types) - (False, True), # fparam_in_data +MODEL_STAT_CURATED_CASES = ( + (([], []), False), + (([], []), True), + (([[0, 1]], [1]), False), + (([[0, 1]], [1]), True), ) + + +@parameterized_cases(*MODEL_STAT_CURATED_CASES) @unittest.skipUnless(INSTALLED_PT and INSTALLED_PT_EXPT, "PT and PT_EXPT are required") class TestDipoleComputeOrLoadStat(unittest.TestCase): """Test that compute_or_load_stat produces identical statistics on dp, pt, and pt_expt. diff --git a/source/tests/consistent/model/test_dos.py b/source/tests/consistent/model/test_dos.py index dfbceeaf4c..7d34c9da53 100644 --- a/source/tests/consistent/model/test_dos.py +++ b/source/tests/consistent/model/test_dos.py @@ -29,7 +29,7 @@ INSTALLED_TF, INSTALLED_TF2, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( ModelTest, @@ -1234,10 +1234,15 @@ def test_get_observed_type_list(self) -> None: self.assertEqual(dp_observed, ["O"]) -@parameterized( - (([], []), ([[0, 1]], [1])), # (pair_exclude_types, atom_exclude_types) - (False, True), # fparam_in_data +MODEL_STAT_CURATED_CASES = ( + (([], []), False), + (([], []), True), + (([[0, 1]], [1]), False), + (([[0, 1]], [1]), True), ) + + +@parameterized_cases(*MODEL_STAT_CURATED_CASES) @unittest.skipUnless(INSTALLED_PT and INSTALLED_PT_EXPT, "PT and PT_EXPT are required") class TestDOSComputeOrLoadStat(unittest.TestCase): """Test that compute_or_load_stat produces identical statistics on dp, pt, and pt_expt. diff --git a/source/tests/consistent/model/test_dpa1.py b/source/tests/consistent/model/test_dpa1.py index b32570d024..fa6493e533 100644 --- a/source/tests/consistent/model/test_dpa1.py +++ b/source/tests/consistent/model/test_dpa1.py @@ -20,7 +20,7 @@ INSTALLED_TF, SKIP_FLAG, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( ModelTest, @@ -55,11 +55,14 @@ EnergyModelJAX = None -@parameterized( - ("strip", "concat"), # tebd_input_mode - # strip + smooth is inconsistent - (False,), # smooth +DPA1_ENER_CURATED_CASES = ( + ("strip", False), + ("concat", False), + ("concat", True), ) + + +@parameterized_cases(*DPA1_ENER_CURATED_CASES) class TestDPA1Ener(CommonTest, ModelTest, unittest.TestCase): @property def data(self) -> dict: @@ -125,6 +128,11 @@ def get_reference_backend(self): return self.RefBackend.DP raise ValueError("No available reference") + @property + def skip_tf(self) -> bool: + _, smooth_type_embedding = self.param + return CommonTest.skip_tf or smooth_type_embedding + @property def skip_jax(self) -> bool: return not INSTALLED_JAX diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index 33e2c0c27b..800eaa072d 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -31,7 +31,7 @@ INSTALLED_TF2, SKIP_FLAG, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( ModelTest, @@ -80,16 +80,28 @@ EnergyModelJAX = None -@parameterized( - ( - [], - [[0, 1]], - ), - ( - [], - [1], - ), +MODEL_EXCLUSION_CURATED_CASES = ( + ([], []), + ([], [1]), + ([[0, 1]], []), + ([[0, 1]], [1]), ) + +MODEL_STAT_CURATED_CASES = ( + (([], []), False), + (([], []), True), + (([[0, 1]], [1]), False), + (([[0, 1]], [1]), True), +) + +ENER_CHG_SPIN_EBD_FPARAM_CURATED_CASES = ( + ("no_chg_spin",), + ("explicit_chg_spin",), + ("default_chg_spin",), +) + + +@parameterized_cases(*MODEL_EXCLUSION_CURATED_CASES) class TestEner(CommonTest, ModelTest, unittest.TestCase): @property def data(self) -> dict: @@ -328,16 +340,7 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: raise ValueError(f"Unknown backend: {backend}") -@parameterized( - ( - [], - [[0, 1]], - ), - ( - [], - [1], - ), -) +@parameterized_cases(*MODEL_EXCLUSION_CURATED_CASES) class TestEnerLower(CommonTest, ModelTest, unittest.TestCase): @property def data(self) -> dict: @@ -1718,10 +1721,7 @@ def test_get_observed_type_list(self) -> None: self.assertEqual(dp_observed, ["O"]) -@parameterized( - (([], []), ([[0, 1]], [1])), # (pair_exclude_types, atom_exclude_types) - (False, True), # fparam_in_data -) +@parameterized_cases(*MODEL_STAT_CURATED_CASES) @unittest.skipUnless(INSTALLED_PT and INSTALLED_PT_EXPT, "PT and PT_EXPT are required") class TestEnerComputeOrLoadStat(unittest.TestCase): """Test that compute_or_load_stat produces identical statistics on dp, pt, and pt_expt. @@ -2064,9 +2064,7 @@ def raise_error(): compare_variables_recursive(dp_ser_loaded, pe_ser_loaded) -@parameterized( - ("no_chg_spin", "explicit_chg_spin", "default_chg_spin"), # cs_mode -) +@parameterized_cases(*ENER_CHG_SPIN_EBD_FPARAM_CURATED_CASES) @unittest.skipUnless(INSTALLED_PT and INSTALLED_PT_EXPT, "PT and PT_EXPT are required") class TestEnerChgSpinEbdFparam(unittest.TestCase): """Test dp/pt/pt_expt model forward consistency for add_chg_spin_ebd with three modes. diff --git a/source/tests/consistent/model/test_frozen.py b/source/tests/consistent/model/test_frozen.py index c19b1a8cb4..4ab62c699c 100644 --- a/source/tests/consistent/model/test_frozen.py +++ b/source/tests/consistent/model/test_frozen.py @@ -18,7 +18,7 @@ INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( ModelTest, @@ -50,6 +50,12 @@ tf_model = "deeppot_for_consistent_frozen.pb" dp_model = "deeppot_for_consistent_frozen.dp" +FROZEN_MODEL_CURATED_CASES = ( + (pt_model,), + (tf_model,), + (dp_model,), +) + def setUpModule() -> None: case = get_cases()["se_e2_a"] @@ -72,7 +78,7 @@ def tearDownModule() -> None: os.remove(tmp_pb) -@parameterized((pt_model, tf_model, dp_model)) +@parameterized_cases(*FROZEN_MODEL_CURATED_CASES) class TestFrozen(CommonTest, ModelTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/model/test_linear_ener.py b/source/tests/consistent/model/test_linear_ener.py index 9583a03668..1d3708f31b 100644 --- a/source/tests/consistent/model/test_linear_ener.py +++ b/source/tests/consistent/model/test_linear_ener.py @@ -15,7 +15,7 @@ INSTALLED_PT_EXPT, SKIP_FLAG, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( ModelTest, @@ -41,17 +41,15 @@ model_args, ) - -@parameterized( - ( - [], - [[0, 1]], - ), - ( - [], - [1], - ), +LINEAR_ENER_MODEL_EXCLUSION_CURATED_CASES = ( + ([], []), + ([], [1]), + ([[0, 1]], []), + ([[0, 1]], [1]), ) + + +@parameterized_cases(*LINEAR_ENER_MODEL_EXCLUSION_CURATED_CASES) class TestLinearEner(CommonTest, ModelTest, unittest.TestCase): @property def data(self) -> dict: diff --git a/source/tests/consistent/model/test_polar.py b/source/tests/consistent/model/test_polar.py index b4eac70735..811ef7ad41 100644 --- a/source/tests/consistent/model/test_polar.py +++ b/source/tests/consistent/model/test_polar.py @@ -29,7 +29,7 @@ INSTALLED_TF, INSTALLED_TF2, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( ModelTest, @@ -1239,10 +1239,15 @@ def test_get_observed_type_list(self) -> None: self.assertEqual(dp_observed, ["O"]) -@parameterized( - (([], []), ([[0, 1]], [1])), # (pair_exclude_types, atom_exclude_types) - (False, True), # fparam_in_data +MODEL_STAT_CURATED_CASES = ( + (([], []), False), + (([], []), True), + (([[0, 1]], [1]), False), + (([[0, 1]], [1]), True), ) + + +@parameterized_cases(*MODEL_STAT_CURATED_CASES) @unittest.skipUnless(INSTALLED_PT and INSTALLED_PT_EXPT, "PT and PT_EXPT are required") class TestPolarComputeOrLoadStat(unittest.TestCase): """Test that compute_or_load_stat produces identical statistics on dp, pt, and pt_expt. @@ -1458,9 +1463,15 @@ def test_compute_stat(self) -> None: ) # 2. Run compute_or_load_stat on all three backends - self.dp_model.compute_or_load_stat(lambda: self.np_sampled) - self.pt_model.compute_or_load_stat(lambda: self.pt_sampled) - self.pt_expt_model.compute_or_load_stat(lambda: self.np_sampled) + # deepcopy because stat.py mutates natoms in-place when atom_exclude_types + # is non-empty (natoms[:, 2:] *= type_mask). + from copy import ( + deepcopy, + ) + + self.dp_model.compute_or_load_stat(lambda: deepcopy(self.np_sampled)) + self.pt_model.compute_or_load_stat(lambda: deepcopy(self.pt_sampled)) + self.pt_expt_model.compute_or_load_stat(lambda: deepcopy(self.np_sampled)) # 3. Serialize all three and compare @variables dp_ser = self.dp_model.serialize() @@ -1515,6 +1526,9 @@ def test_compute_stat(self) -> None: def test_load_stat_from_file(self) -> None: import tempfile + from copy import ( + deepcopy, + ) from pathlib import ( Path, ) @@ -1536,13 +1550,13 @@ def test_load_stat_from_file(self) -> None: # 1. Compute stats and save to file self.dp_model.compute_or_load_stat( - lambda: self.np_sampled, stat_file_path=DPPath(dp_h5, "a") + lambda: deepcopy(self.np_sampled), stat_file_path=DPPath(dp_h5, "a") ) self.pt_model.compute_or_load_stat( - lambda: self.pt_sampled, stat_file_path=DPPath(pt_h5, "a") + lambda: deepcopy(self.pt_sampled), stat_file_path=DPPath(pt_h5, "a") ) self.pt_expt_model.compute_or_load_stat( - lambda: self.np_sampled, stat_file_path=DPPath(pe_h5, "a") + lambda: deepcopy(self.np_sampled), stat_file_path=DPPath(pe_h5, "a") ) # Save the computed serializations as reference diff --git a/source/tests/consistent/model/test_property.py b/source/tests/consistent/model/test_property.py index 28affedc11..047d6a76ed 100644 --- a/source/tests/consistent/model/test_property.py +++ b/source/tests/consistent/model/test_property.py @@ -28,7 +28,7 @@ INSTALLED_PT_EXPT, INSTALLED_TF2, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( ModelTest, @@ -1231,10 +1231,15 @@ def test_get_observed_type_list(self) -> None: self.assertEqual(dp_observed, ["O"]) -@parameterized( - (([], []), ([[0, 1]], [1])), # (pair_exclude_types, atom_exclude_types) - (False, True), # fparam_in_data +MODEL_STAT_CURATED_CASES = ( + (([], []), False), + (([], []), True), + (([[0, 1]], [1]), False), + (([[0, 1]], [1]), True), ) + + +@parameterized_cases(*MODEL_STAT_CURATED_CASES) @unittest.skipUnless(INSTALLED_PT and INSTALLED_PT_EXPT, "PT and PT_EXPT are required") class TestPropertyComputeOrLoadStat(unittest.TestCase): """Test that compute_or_load_stat produces identical statistics on dp, pt, and pt_expt. diff --git a/source/tests/consistent/model/test_zbl_ener.py b/source/tests/consistent/model/test_zbl_ener.py index 0cf907bee2..ba4e687b16 100644 --- a/source/tests/consistent/model/test_zbl_ener.py +++ b/source/tests/consistent/model/test_zbl_ener.py @@ -30,7 +30,7 @@ INSTALLED_PT_EXPT, SKIP_FLAG, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( ModelTest, @@ -61,16 +61,20 @@ TESTS_DIR = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) -@parameterized( - ( - [], - [[0, 1]], - ), - ( - [], - [1], - ), +ZBL_MODEL_EXCLUSION_CURATED_CASES = ( + ([], []), + ([], [1]), + ([[0, 1]], []), + ([[0, 1]], [1]), ) + +ZBL_MODEL_STAT_CURATED_CASES = ( + (([], []),), + (([[0, 1]], [1]),), +) + + +@parameterized_cases(*ZBL_MODEL_EXCLUSION_CURATED_CASES) class TestEner(CommonTest, ModelTest, unittest.TestCase): @property def data(self) -> dict: @@ -1018,9 +1022,7 @@ def test_get_observed_type_list(self) -> None: self.assertEqual(dp_observed, ["O"]) -@parameterized( - (([], []), ([[0, 1]], [1])), # (pair_exclude_types, atom_exclude_types) -) +@parameterized_cases(*ZBL_MODEL_STAT_CURATED_CASES) @unittest.skipUnless(INSTALLED_PT and INSTALLED_PT_EXPT, "PT and PT_EXPT are required") class TestZBLComputeOrLoadStat(unittest.TestCase): """Test that compute_or_load_stat produces identical statistics on dp, pt, and pt_expt. diff --git a/source/tests/consistent/test_activation.py b/source/tests/consistent/test_activation.py index 21194ce66b..1c388b7305 100644 --- a/source/tests/consistent/test_activation.py +++ b/source/tests/consistent/test_activation.py @@ -22,7 +22,7 @@ INSTALLED_PT_EXPT, INSTALLED_TF, INSTALLED_TF2, - parameterized, + parameterized_cases, ) if INSTALLED_PT: @@ -55,9 +55,10 @@ ) -@parameterized( - tuple([x.capitalize() for x in VALID_ACTIVATION]), -) +ACTIVATION_FUNCTION_CURATED_CASES = tuple((x.capitalize(),) for x in VALID_ACTIVATION) + + +@parameterized_cases(*ACTIVATION_FUNCTION_CURATED_CASES) class TestActivationFunctionConsistent(unittest.TestCase): def setUp(self) -> None: (self.activation,) = self.param @@ -128,14 +129,15 @@ def test_pt_expt_consistent_with_ref(self) -> None: np.testing.assert_allclose(self.ref, test, atol=1e-10) -@parameterized( - ( - "silut", # default threshold 3.0 - "silut:3.0", # explicit threshold 3.0 - "silut:10.0", # large threshold - "custom_silu:5.0", # alias - ), +SILUT_VARIANT_CURATED_CASES = ( + ("silut",), # default threshold 3.0 + ("silut:3.0",), # explicit threshold 3.0 + ("silut:10.0",), # large threshold + ("custom_silu:5.0",), # alias ) + + +@parameterized_cases(*SILUT_VARIANT_CURATED_CASES) class TestSilutVariantsConsistent(unittest.TestCase): """Cross-backend consistency for silut with different thresholds.""" diff --git a/source/tests/consistent/test_learning_rate.py b/source/tests/consistent/test_learning_rate.py index 56142a675d..48aea94326 100644 --- a/source/tests/consistent/test_learning_rate.py +++ b/source/tests/consistent/test_learning_rate.py @@ -19,7 +19,7 @@ INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF2, - parameterized, + parameterized_cases, ) if INSTALLED_PT: @@ -35,7 +35,7 @@ import array_api_strict as xp -@parameterized( +LEARNING_RATE_CURATED_CASES = ( ( { "type": "exp", @@ -45,6 +45,8 @@ "num_steps": 1000000, "warmup_steps": 10000, }, + ), + ( { "type": "cosine", "start_lr": 1e-3, @@ -52,6 +54,8 @@ "num_steps": 1000000, "warmup_steps": 10000, }, + ), + ( { "type": "wsd", "start_lr": 1e-3, @@ -60,6 +64,8 @@ "warmup_steps": 10000, "decay_phase_ratio": 0.1, }, + ), + ( { "type": "wsd", "start_lr": 1e-3, @@ -69,6 +75,8 @@ "decay_phase_ratio": 0.1, "decay_type": "cosine", }, + ), + ( { "type": "wsd", "start_lr": 1e-3, @@ -80,6 +88,9 @@ }, ), ) + + +@parameterized_cases(*LEARNING_RATE_CURATED_CASES) class TestLearningRateConsistent(unittest.TestCase): """Test learning rate consistency across different array backends.""" diff --git a/source/tests/consistent/test_type_embedding.py b/source/tests/consistent/test_type_embedding.py index 20853c4a7c..08ebd36040 100644 --- a/source/tests/consistent/test_type_embedding.py +++ b/source/tests/consistent/test_type_embedding.py @@ -22,7 +22,7 @@ INSTALLED_TF, INSTALLED_TF2, CommonTest, - parameterized, + parameterized_cases, ) if INSTALLED_PT: @@ -60,13 +60,49 @@ TypeEmbedNetTF2 = None -@parameterized( - (True, False), # resnet_dt - ("float32", "float64"), # precision - (True, False), # padding - (True, False), # use_econf_tebd - (True, False), # use_tebd_bias +TYPE_EMBEDDING_CASE_FIELDS = ( + "resnet_dt", + "precision", + "padding", + "use_econf_tebd", + "use_tebd_bias", ) + +TYPE_EMBEDDING_BASELINE_CASE = { + "resnet_dt": True, + "precision": "float64", + "padding": True, + "use_econf_tebd": True, + "use_tebd_bias": True, +} + + +def type_embedding_case(**overrides: Any) -> tuple: + unknown = set(overrides) - set(TYPE_EMBEDDING_CASE_FIELDS) + if unknown: + raise ValueError(f"Unknown type-embedding case fields: {sorted(unknown)}") + case = TYPE_EMBEDDING_BASELINE_CASE | overrides + return tuple(case[field] for field in TYPE_EMBEDDING_CASE_FIELDS) + + +TYPE_EMBEDDING_CURATED_CASES = ( + type_embedding_case(), + type_embedding_case(resnet_dt=False), + type_embedding_case(precision="float32"), + type_embedding_case(padding=False), + type_embedding_case(use_econf_tebd=False), + type_embedding_case(use_tebd_bias=False), + type_embedding_case( + resnet_dt=False, + precision="float32", + padding=False, + use_econf_tebd=False, + use_tebd_bias=False, + ), +) + + +@parameterized_cases(*TYPE_EMBEDDING_CURATED_CASES) class TestTypeEmbedding(CommonTest, unittest.TestCase): """Useful utilities for descriptor tests."""