Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 71 additions & 38 deletions source/tests/consistent/descriptor/test_se_atten_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
INSTALLED_PT_EXPT,
INSTALLED_TF2,
CommonTest,
parameterized,
parameterized_cases,
)
from .common import (
DescriptorAPITest,
Expand Down Expand Up @@ -70,26 +70,77 @@
descrpt_se_atten_args,
)

SE_ATTEN_V2_CASE_FIELDS = (
"tebd_dim",
"resnet_dt",
"type_one_side",
"attn",
"attn_layer",
"attn_dotr",
"excluded_types",
"env_protection",
"set_davg_zero",
"scaling_factor",
"normalize",
"temperature",
"ln_eps",
"concat_output_tebd",
"precision",
"use_econf_tebd",
"use_tebd_bias",
)

@parameterized(
(4,), # tebd_dim
(True,), # resnet_dt
(True, False), # type_one_side
(20,), # attn
(0, 2), # attn_layer
(True, False), # attn_dotr
([], [[0, 1]]), # excluded_types
(0.0,), # env_protection
(True, False), # set_davg_zero
(1.0,), # scaling_factor
(True, False), # normalize
(None, 1.0), # temperature
(1e-5,), # ln_eps
(True,), # concat_output_tebd
("float64",), # precision
(True, False), # use_econf_tebd
(False,), # use_tebd_bias
SE_ATTEN_V2_BASELINE_CASE = {
"tebd_dim": 4,
"resnet_dt": True,
"type_one_side": True,
"attn": 20,
"attn_layer": 2,
"attn_dotr": True,
"excluded_types": [],
"env_protection": 0.0,
"set_davg_zero": True,
"scaling_factor": 1.0,
"normalize": True,
"temperature": None,
"ln_eps": 1e-5,
"concat_output_tebd": True,
"precision": "float64",
"use_econf_tebd": True,
"use_tebd_bias": False,
}


def se_atten_v2_case(**overrides: Any) -> tuple:
case = SE_ATTEN_V2_BASELINE_CASE | overrides
return tuple(case[field] for field in SE_ATTEN_V2_CASE_FIELDS)


SE_ATTEN_V2_CURATED_CASES = (
se_atten_v2_case(),
se_atten_v2_case(type_one_side=False),
se_atten_v2_case(attn_layer=0, attn_dotr=False, normalize=False),
se_atten_v2_case(attn_dotr=False),
se_atten_v2_case(excluded_types=[[0, 1]]),
se_atten_v2_case(set_davg_zero=False),
se_atten_v2_case(normalize=False),
se_atten_v2_case(temperature=1.0),
se_atten_v2_case(use_econf_tebd=False),
se_atten_v2_case(
type_one_side=False,
attn_dotr=False,
excluded_types=[[0, 1]],
set_davg_zero=False,
normalize=False,
temperature=1.0,
use_econf_tebd=False,
),
)

SE_ATTEN_V2_DESCRIPTOR_API_CURATED_CASES = SE_ATTEN_V2_CURATED_CASES


@parameterized_cases(*SE_ATTEN_V2_CURATED_CASES)
class TestSeAttenV2(CommonTest, DescriptorTest, unittest.TestCase):
@property
def data(self) -> dict:
Expand Down Expand Up @@ -543,25 +594,7 @@ def atol(self) -> float:
raise ValueError(f"Unknown precision: {precision}")


@parameterized(
(4,), # tebd_dim
(True,), # resnet_dt
(True, False), # type_one_side
(20,), # attn
(0, 2), # attn_layer
(True, False), # attn_dotr
([], [[0, 1]]), # excluded_types
(0.0,), # env_protection
(True, False), # set_davg_zero
(1.0,), # scaling_factor
(True, False), # normalize
(None, 1.0), # temperature
(1e-5,), # ln_eps
(True,), # concat_output_tebd
("float64",), # precision
(True, False), # use_econf_tebd
(False,), # use_tebd_bias
)
@parameterized_cases(*SE_ATTEN_V2_DESCRIPTOR_API_CURATED_CASES)
class TestSeAttenV2DescriptorAPI(DescriptorAPITest, unittest.TestCase):
"""Test consistency of BaseDescriptor API methods across backends."""

Expand Down
50 changes: 38 additions & 12 deletions source/tests/consistent/descriptor/test_se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
INSTALLED_TF,
INSTALLED_TF2,
CommonTest,
parameterized,
parameterized_cases,
)
from .common import (
DescriptorAPITest,
Expand Down Expand Up @@ -58,12 +58,43 @@
DescrptSeRArrayAPIStrict = None


@parameterized(
(True, False), # resnet_dt
(True, False), # type_one_side
([], [[0, 1]]), # excluded_types
("float32", "float64"), # precision
SE_R_CASE_FIELDS = (
"resnet_dt",
"type_one_side",
"excluded_types",
"precision",
)

SE_R_BASELINE_CASE = {
"resnet_dt": True,
"type_one_side": True,
"excluded_types": [],
"precision": "float64",
}


def se_r_case(**overrides: Any) -> tuple:
case = SE_R_BASELINE_CASE | overrides
return tuple(case[field] for field in SE_R_CASE_FIELDS)


SE_R_CURATED_CASES = (
se_r_case(),
se_r_case(resnet_dt=False),
se_r_case(type_one_side=False),
se_r_case(excluded_types=[[0, 1]]),
se_r_case(precision="float32"),
)

SE_R_DESCRIPTOR_API_CURATED_CASES = (
se_r_case(),
se_r_case(resnet_dt=False),
se_r_case(type_one_side=False),
se_r_case(excluded_types=[[0, 1]]),
)


@parameterized_cases(*SE_R_CURATED_CASES)
class TestSeR(CommonTest, DescriptorTest, unittest.TestCase):
@property
def data(self) -> dict:
Expand Down Expand Up @@ -289,12 +320,7 @@ def atol(self) -> float:
raise ValueError(f"Unknown precision: {precision}")


@parameterized(
(True, False), # resnet_dt
(True, False), # type_one_side
([], [[0, 1]]), # excluded_types
("float64",), # precision
)
@parameterized_cases(*SE_R_DESCRIPTOR_API_CURATED_CASES)
class TestSeRDescriptorAPI(DescriptorAPITest, unittest.TestCase):
"""Test consistency of BaseDescriptor API methods across backends."""

Expand Down
51 changes: 39 additions & 12 deletions source/tests/consistent/descriptor/test_se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
INSTALLED_TF,
INSTALLED_TF2,
CommonTest,
parameterized,
parameterized_cases,
)
from .common import (
DescriptorAPITest,
Expand Down Expand Up @@ -54,13 +54,45 @@
descrpt_se_t_args,
)

SE_T_CASE_FIELDS = (
"resnet_dt",
"excluded_types",
"precision",
"env_protection",
)

SE_T_BASELINE_CASE = {
"resnet_dt": True,
"excluded_types": [],
"precision": "float64",
"env_protection": 0.0,
}


def se_t_case(**overrides: Any) -> tuple:
case = SE_T_BASELINE_CASE | overrides
return tuple(case[field] for field in SE_T_CASE_FIELDS)


@parameterized(
(True, False), # resnet_dt
([], [[0, 1]]), # excluded_types
("float32", "float64"), # precision
(0.0, 1e-8, 1e-2), # env_protection
SE_T_CURATED_CASES = (
se_t_case(),
se_t_case(resnet_dt=False),
se_t_case(excluded_types=[[0, 1]]),
se_t_case(precision="float32"),
se_t_case(env_protection=1e-8),
se_t_case(env_protection=1e-2),
)

SE_T_DESCRIPTOR_API_CURATED_CASES = (
se_t_case(),
se_t_case(resnet_dt=False),
se_t_case(excluded_types=[[0, 1]]),
se_t_case(env_protection=1e-8),
se_t_case(env_protection=1e-2),
)


@parameterized_cases(*SE_T_CURATED_CASES)
class TestSeT(CommonTest, DescriptorTest, unittest.TestCase):
@property
def data(self) -> dict:
Expand Down Expand Up @@ -280,12 +312,7 @@ def atol(self) -> float:
raise ValueError(f"Unknown precision: {precision}")


@parameterized(
(True, False), # resnet_dt
([], [[0, 1]]), # excluded_types
("float64",), # precision
(0.0, 1e-8, 1e-2), # env_protection
)
@parameterized_cases(*SE_T_DESCRIPTOR_API_CURATED_CASES)
class TestSeTDescriptorAPI(DescriptorAPITest, unittest.TestCase):
"""Test consistency of BaseDescriptor API methods across backends."""

Expand Down
80 changes: 54 additions & 26 deletions source/tests/consistent/descriptor/test_se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
INSTALLED_PT_EXPT,
INSTALLED_TF2,
CommonTest,
parameterized,
parameterized_cases,
)
from .common import (
DescriptorAPITest,
Expand Down Expand Up @@ -62,20 +62,60 @@
descrpt_se_e3_tebd_args,
)

SE_T_TEBD_CASE_FIELDS = (
"tebd_dim",
"tebd_input_mode",
"resnet_dt",
"excluded_types",
"env_protection",
"set_davg_zero",
"smooth",
"concat_output_tebd",
"precision",
"use_econf_tebd",
"use_tebd_bias",
)

SE_T_TEBD_BASELINE_CASE = {
"tebd_dim": 4,
"tebd_input_mode": "strip",
"resnet_dt": True,
"excluded_types": [],
"env_protection": 0.0,
"set_davg_zero": True,
"smooth": True,
"concat_output_tebd": True,
"precision": "float64",
"use_econf_tebd": True,
"use_tebd_bias": False,
}


def se_t_tebd_case(**overrides: Any) -> tuple:
case = SE_T_TEBD_BASELINE_CASE | overrides
return tuple(case[field] for field in SE_T_TEBD_CASE_FIELDS)


@parameterized(
(4,), # tebd_dim
("strip",), # tebd_input_mode
(True,), # resnet_dt
([], [[0, 1]]), # excluded_types
(0.0,), # env_protection
(True, False), # set_davg_zero
(True, False), # smooth
(True,), # concat_output_tebd
("float64",), # precision
(True, False), # use_econf_tebd
(False, True), # use_tebd_bias
SE_T_TEBD_CURATED_CASES = (
se_t_tebd_case(),
se_t_tebd_case(excluded_types=[[0, 1]]),
se_t_tebd_case(set_davg_zero=False),
se_t_tebd_case(smooth=False),
se_t_tebd_case(use_econf_tebd=False),
se_t_tebd_case(use_tebd_bias=True),
se_t_tebd_case(
excluded_types=[[0, 1]],
set_davg_zero=False,
smooth=False,
use_econf_tebd=False,
use_tebd_bias=True,
),
)

SE_T_TEBD_DESCRIPTOR_API_CURATED_CASES = SE_T_TEBD_CURATED_CASES


@parameterized_cases(*SE_T_TEBD_CURATED_CASES)
class TestSeTTebd(CommonTest, DescriptorTest, unittest.TestCase):
@property
def data(self) -> dict:
Expand Down Expand Up @@ -374,19 +414,7 @@ def atol(self) -> float:
raise ValueError(f"Unknown precision: {precision}")


@parameterized(
(4,), # tebd_dim
("strip",), # tebd_input_mode
(True,), # resnet_dt
([], [[0, 1]]), # excluded_types
(0.0,), # env_protection
(True, False), # set_davg_zero
(True, False), # smooth
(True,), # concat_output_tebd
("float64",), # precision
(True, False), # use_econf_tebd
(False, True), # use_tebd_bias
)
@parameterized_cases(*SE_T_TEBD_DESCRIPTOR_API_CURATED_CASES)
class TestSeTTebdDescriptorAPI(DescriptorAPITest, unittest.TestCase):
"""Test consistency of BaseDescriptor API methods across backends."""

Expand Down
Loading
Loading