diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index ed4e2caab9..d8fb9e64e0 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -2,6 +2,7 @@ import inspect import itertools import os +import re import sys import unittest from abc import ( @@ -75,7 +76,9 @@ "INSTALLED_PT_EXPT", "INSTALLED_TF", "CommonTest", - "CommonTest", + "parameterize_func", + "parameterized", + "parameterized_cases", ] SKIP_FLAG = object() @@ -670,6 +673,46 @@ def tearDown(self) -> None: clear_session() +def _parameterized_with_cases(full_parameterized: list[tuple]) -> Callable: + def decorator(base_class: type): + class_module = sys.modules[base_class.__module__].__dict__ + used_names: set[str] = set() + for pp in full_parameterized: + + class TestClass(base_class): + param: ClassVar = pp + + # generate a safe name for the class + parts = [] + for x in pp: + s = str(x) + # replace non-alnum with underscore, collapse multiple underscores + s = re.sub(r"[^a-zA-Z0-9_]", "_", s) + s = re.sub(r"_+", "_", s) + # remove leading/trailing underscores + s = s.strip("_") + if s == "": + s = "empty" + parts.append(s) + base_name = f"{base_class.__name__}_{'_'.join(parts)}" + name = base_name + suffix = 1 + while name in used_names or name in class_module: + name = f"{base_name}_{suffix}" + suffix += 1 + + TestClass.__name__ = name + TestClass.__qualname__ = name + TestClass.__module__ = base_class.__module__ + + used_names.add(name) + class_module[name] = TestClass + # make unittest module happy by ignoring the original one + return object + + return decorator + + def parameterized(*attrs: tuple, **subblock_attrs: tuple) -> Callable: """Parameterized test. @@ -733,6 +776,27 @@ class TestClass(base_class): return decorator +def parameterized_cases(*cases: tuple) -> Callable: + """Parameterized test with explicit case tuples. + + This variant behaves like :func:`parameterized` but takes a curated list of + case tuples directly instead of computing their Cartesian product. + + Parameters + ---------- + *cases : tuple + Explicit case tuples. + + Returns + ------- + object + The decorator. + """ + if not cases: + raise ValueError("parameterized_cases requires at least one case tuple") + return _parameterized_with_cases(list(cases)) + + def parameterize_func( func: Callable, param_dict_list: dict[str, tuple], diff --git a/source/tests/consistent/descriptor/test_dpa1.py b/source/tests/consistent/descriptor/test_dpa1.py index 60748bd17b..d1e8a97497 100644 --- a/source/tests/consistent/descriptor/test_dpa1.py +++ b/source/tests/consistent/descriptor/test_dpa1.py @@ -22,7 +22,7 @@ INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( DescriptorAPITest, @@ -57,28 +57,76 @@ descrpt_se_atten_args, ) +DPA1_CASE_FIELDS = ( + "tebd_dim", + "tebd_input_mode", + "resnet_dt", + "type_one_side", + "attn", + "attn_layer", + "attn_dotr", + "excluded_types", + "env_protection", + "set_davg_zero", + "scaling_factor", + "normalize", + "temperature", + "ln_eps", + "smooth_type_embedding", + "concat_output_tebd", + "precision", + "use_econf_tebd", + "use_tebd_bias", +) + -@parameterized( - (4,), # tebd_dim - ("concat", "strip"), # tebd_input_mode - (True,), # resnet_dt - (True,), # type_one_side - (20,), # attn - (0, 2), # attn_layer - (True,), # attn_dotr - ([], [[0, 1]]), # excluded_types - (0.0,), # env_protection - (True, False), # set_davg_zero - (1.0,), # scaling_factor - (True,), # normalize - (None, 1.0), # temperature - (1e-5,), # ln_eps - (True,), # smooth_type_embedding - (True,), # concat_output_tebd - ("float64",), # precision - (True, False), # use_econf_tebd - (False,), # use_tebd_bias +DPA1_BASELINE_CASE = { + "tebd_dim": 4, + "tebd_input_mode": "concat", + "resnet_dt": True, + "type_one_side": True, + "attn": 20, + "attn_layer": 2, + "attn_dotr": True, + "excluded_types": [], + "env_protection": 0.0, + "set_davg_zero": True, + "scaling_factor": 1.0, + "normalize": True, + "temperature": 1.0, + "ln_eps": 1e-5, + "smooth_type_embedding": True, + "concat_output_tebd": True, + "precision": "float64", + "use_econf_tebd": False, + "use_tebd_bias": False, +} + + +def dpa1_case(**overrides: Any) -> tuple: + case = DPA1_BASELINE_CASE | overrides + return tuple(case[field] for field in DPA1_CASE_FIELDS) + + +DPA1_CURATED_CASES = ( + # Baseline coverage. + dpa1_case(), + # Alternate tebd input plumbing. + dpa1_case(tebd_input_mode="strip"), + # High-risk descriptor toggles. + dpa1_case(excluded_types=[[0, 1]]), + dpa1_case(set_davg_zero=False), + dpa1_case(normalize=False), + # Attention edge cases: disabled temperature path vs zero-layer path. + dpa1_case(temperature=None), + dpa1_case(attn_layer=0, temperature=None), + # econf-specific path with both tebd input modes. + dpa1_case(use_econf_tebd=True), + dpa1_case(tebd_input_mode="strip", use_econf_tebd=True), ) + + +@parameterized_cases(*DPA1_CURATED_CASES) class TestDPA1(CommonTest, DescriptorTest, unittest.TestCase): @property def data(self) -> dict: @@ -556,27 +604,7 @@ def atol(self) -> float: raise ValueError(f"Unknown precision: {precision}") -@parameterized( - (4,), # tebd_dim - ("concat", "strip"), # tebd_input_mode - (True,), # resnet_dt - (True,), # type_one_side - (20,), # attn - (0, 2), # attn_layer - (True,), # attn_dotr - ([], [[0, 1]]), # excluded_types - (0.0,), # env_protection - (True, False), # set_davg_zero - (1.0,), # scaling_factor - (True,), # normalize - (None, 1.0), # temperature - (1e-5,), # ln_eps - (True,), # smooth_type_embedding - (True,), # concat_output_tebd - ("float64",), # precision - (True, False), # use_econf_tebd - (False,), # use_tebd_bias -) +@parameterized_cases(*DPA1_CURATED_CASES) class TestDPA1DescriptorAPI(DescriptorAPITest, unittest.TestCase): """Test consistency of BaseDescriptor API methods across backends.""" diff --git a/source/tests/consistent/descriptor/test_dpa2.py b/source/tests/consistent/descriptor/test_dpa2.py index 9a48b84d9d..fe4540e240 100644 --- a/source/tests/consistent/descriptor/test_dpa2.py +++ b/source/tests/consistent/descriptor/test_dpa2.py @@ -21,7 +21,7 @@ INSTALLED_PT, INSTALLED_PT_EXPT, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( DescriptorAPITest, @@ -63,36 +63,106 @@ descrpt_dpa2_args, ) +DPA2_CASE_FIELDS = ( + "repinit_tebd_input_mode", + "repinit_set_davg_zero", + "repinit_type_one_side", + "repinit_use_three_body", + "repformer_direct_dist", + "repformer_update_g1_has_conv", + "repformer_update_g1_has_drrd", + "repformer_update_g1_has_grrg", + "repformer_update_g1_has_attn", + "repformer_update_g2_has_g1g1", + "repformer_update_g2_has_attn", + "repformer_update_h2", + "repformer_attn2_has_gate", + "repformer_update_style", + "repformer_update_residual_init", + "repformer_set_davg_zero", + "repformer_trainable_ln", + "repformer_ln_eps", + "repformer_use_sqrt_nnei", + "repformer_g1_out_conv", + "repformer_g1_out_mlp", + "smooth", + "exclude_types", + "precision", + "add_tebd_to_repinit_out", + "use_econf_tebd", + "use_tebd_bias", +) + -@parameterized( - ("concat", "strip"), # repinit_tebd_input_mode - (True,), # repinit_set_davg_zero - (False,), # repinit_type_one_side - (True, False), # repinit_use_three_body - (True, False), # repformer_direct_dist - (True,), # repformer_update_g1_has_conv - (True,), # repformer_update_g1_has_drrd - (True,), # repformer_update_g1_has_grrg - (True,), # repformer_update_g1_has_attn - (True,), # repformer_update_g2_has_g1g1 - (True,), # repformer_update_g2_has_attn - (False,), # repformer_update_h2 - (True,), # repformer_attn2_has_gate - ("res_avg", "res_residual"), # repformer_update_style - ("norm", "const"), # repformer_update_residual_init - (True,), # repformer_set_davg_zero - (True,), # repformer_trainable_ln - (1e-5,), # repformer_ln_eps - (True,), # repformer_use_sqrt_nnei - (True,), # repformer_g1_out_conv - (True,), # repformer_g1_out_mlp - (True, False), # smooth - ([], [[0, 1]]), # exclude_types - ("float64",), # precision - (True, False), # add_tebd_to_repinit_out - (True, False), # use_econf_tebd - (False,), # use_tebd_bias +DPA2_BASELINE_CASE = { + "repinit_tebd_input_mode": "concat", + "repinit_set_davg_zero": True, + "repinit_type_one_side": False, + "repinit_use_three_body": True, + "repformer_direct_dist": True, + "repformer_update_g1_has_conv": True, + "repformer_update_g1_has_drrd": True, + "repformer_update_g1_has_grrg": True, + "repformer_update_g1_has_attn": True, + "repformer_update_g2_has_g1g1": True, + "repformer_update_g2_has_attn": True, + "repformer_update_h2": False, + "repformer_attn2_has_gate": True, + "repformer_update_style": "res_avg", + "repformer_update_residual_init": "norm", + "repformer_set_davg_zero": True, + "repformer_trainable_ln": True, + "repformer_ln_eps": 1e-5, + "repformer_use_sqrt_nnei": True, + "repformer_g1_out_conv": True, + "repformer_g1_out_mlp": True, + "smooth": True, + "exclude_types": [], + "precision": "float64", + "add_tebd_to_repinit_out": True, + "use_econf_tebd": False, + "use_tebd_bias": False, +} + + +def dpa2_case(**overrides: Any) -> tuple: + case = DPA2_BASELINE_CASE | overrides + return tuple(case[field] for field in DPA2_CASE_FIELDS) + + +DPA2_CURATED_CASES = ( + # Baseline coverage. + dpa2_case(), + # Alternate repinit embedding path. + dpa2_case(repinit_tebd_input_mode="strip"), + # repinit / repformer structural toggles. + dpa2_case(repinit_use_three_body=False), + # Keep direct_dist and update_g1_has_conv named explicitly: a historical + # tuple/unpack mismatch could silently swap these booleans when raw tuples + # were edited, so curated cases must spell out which branch is changing. + dpa2_case(repformer_direct_dist=False), + dpa2_case(repformer_update_style="res_residual"), + dpa2_case(repformer_update_residual_init="const"), + # Descriptor-level toggles. + dpa2_case(smooth=False), + dpa2_case(exclude_types=[[0, 1]]), + dpa2_case(add_tebd_to_repinit_out=False), + # econf-specific coverage, including one mixed high-risk combination. + dpa2_case(use_econf_tebd=True), + dpa2_case( + repinit_tebd_input_mode="strip", + repformer_direct_dist=False, + repformer_update_style="res_residual", + repformer_update_residual_init="const", + smooth=False, + exclude_types=[[0, 1]], + add_tebd_to_repinit_out=False, + use_econf_tebd=True, + ), ) + + +@parameterized_cases(*DPA2_CURATED_CASES) class TestDPA2(CommonTest, DescriptorTest, unittest.TestCase): @property def data(self) -> dict: @@ -101,8 +171,8 @@ def data(self) -> dict: repinit_set_davg_zero, repinit_type_one_side, repinit_use_three_body, - repformer_update_g1_has_conv, repformer_direct_dist, + repformer_update_g1_has_conv, repformer_update_g1_has_drrd, repformer_update_g1_has_grrg, repformer_update_g1_has_attn, @@ -173,7 +243,7 @@ def data(self) -> dict: "update_style": repformer_update_style, "update_residual": 0.001, "update_residual_init": repformer_update_residual_init, - "set_davg_zero": True, + "set_davg_zero": repformer_set_davg_zero, "trainable_ln": repformer_trainable_ln, "ln_eps": repformer_ln_eps, "use_sqrt_nnei": repformer_use_sqrt_nnei, @@ -201,8 +271,8 @@ def skip_pt(self) -> bool: repinit_set_davg_zero, repinit_type_one_side, repinit_use_three_body, - repformer_update_g1_has_conv, repformer_direct_dist, + repformer_update_g1_has_conv, repformer_update_g1_has_drrd, repformer_update_g1_has_grrg, repformer_update_g1_has_attn, @@ -234,8 +304,8 @@ def skip_pd(self) -> bool: repinit_set_davg_zero, repinit_type_one_side, repinit_use_three_body, - repformer_update_g1_has_conv, repformer_direct_dist, + repformer_update_g1_has_conv, repformer_update_g1_has_drrd, repformer_update_g1_has_grrg, repformer_update_g1_has_attn, @@ -267,8 +337,8 @@ def skip_dp(self) -> bool: repinit_set_davg_zero, repinit_type_one_side, repinit_use_three_body, - repformer_update_g1_has_conv, repformer_direct_dist, + repformer_update_g1_has_conv, repformer_update_g1_has_drrd, repformer_update_g1_has_grrg, repformer_update_g1_has_attn, @@ -300,8 +370,8 @@ def skip_tf(self) -> bool: repinit_set_davg_zero, repinit_type_one_side, repinit_use_three_body, - repformer_update_g1_has_conv, repformer_direct_dist, + repformer_update_g1_has_conv, repformer_update_g1_has_drrd, repformer_update_g1_has_grrg, repformer_update_g1_has_attn, @@ -377,8 +447,8 @@ def setUp(self) -> None: repinit_set_davg_zero, repinit_type_one_side, repinit_use_three_body, - repformer_update_g1_has_conv, repformer_direct_dist, + repformer_update_g1_has_conv, repformer_update_g1_has_drrd, repformer_update_g1_has_grrg, repformer_update_g1_has_attn, @@ -483,8 +553,8 @@ def rtol(self) -> float: repinit_set_davg_zero, repinit_type_one_side, repinit_use_three_body, - repformer_update_g1_has_conv, repformer_direct_dist, + repformer_update_g1_has_conv, repformer_update_g1_has_drrd, repformer_update_g1_has_grrg, repformer_update_g1_has_attn, @@ -522,8 +592,8 @@ def atol(self) -> float: repinit_set_davg_zero, repinit_type_one_side, repinit_use_three_body, - repformer_update_g1_has_conv, repformer_direct_dist, + repformer_update_g1_has_conv, repformer_update_g1_has_drrd, repformer_update_g1_has_grrg, repformer_update_g1_has_attn, @@ -554,35 +624,7 @@ def atol(self) -> float: raise ValueError(f"Unknown precision: {precision}") -@parameterized( - ("concat", "strip"), # repinit_tebd_input_mode - (True,), # repinit_set_davg_zero - (False,), # repinit_type_one_side - (True, False), # repinit_use_three_body - (True, False), # repformer_direct_dist - (True,), # repformer_update_g1_has_conv - (True,), # repformer_update_g1_has_drrd - (True,), # repformer_update_g1_has_grrg - (True,), # repformer_update_g1_has_attn - (True,), # repformer_update_g2_has_g1g1 - (True,), # repformer_update_g2_has_attn - (False,), # repformer_update_h2 - (True,), # repformer_attn2_has_gate - ("res_avg", "res_residual"), # repformer_update_style - ("norm", "const"), # repformer_update_residual_init - (True,), # repformer_set_davg_zero - (True,), # repformer_trainable_ln - (1e-5,), # repformer_ln_eps - (True,), # repformer_use_sqrt_nnei - (True,), # repformer_g1_out_conv - (True,), # repformer_g1_out_mlp - (True, False), # smooth - ([], [[0, 1]]), # exclude_types - ("float64",), # precision - (True, False), # add_tebd_to_repinit_out - (True, False), # use_econf_tebd - (False,), # use_tebd_bias -) +@parameterized_cases(*DPA2_CURATED_CASES) class TestDPA2DescriptorAPI(DescriptorAPITest, unittest.TestCase): """Test consistency of BaseDescriptor API methods across backends.""" @@ -598,8 +640,8 @@ def data(self) -> dict: repinit_set_davg_zero, repinit_type_one_side, repinit_use_three_body, - repformer_update_g1_has_conv, repformer_direct_dist, + repformer_update_g1_has_conv, repformer_update_g1_has_drrd, repformer_update_g1_has_grrg, repformer_update_g1_has_attn, @@ -670,7 +712,7 @@ def data(self) -> dict: "update_style": repformer_update_style, "update_residual": 0.001, "update_residual_init": repformer_update_residual_init, - "set_davg_zero": True, + "set_davg_zero": repformer_set_davg_zero, "trainable_ln": repformer_trainable_ln, "ln_eps": repformer_ln_eps, "use_sqrt_nnei": repformer_use_sqrt_nnei, diff --git a/source/tests/consistent/descriptor/test_se_e2_a.py b/source/tests/consistent/descriptor/test_se_e2_a.py index efa5c4bc6d..2174c5edda 100644 --- a/source/tests/consistent/descriptor/test_se_e2_a.py +++ b/source/tests/consistent/descriptor/test_se_e2_a.py @@ -19,7 +19,7 @@ INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( DescriptorAPITest, @@ -71,13 +71,55 @@ DescrptSeAArrayAPIStrict = None -@parameterized( - (True, False), # resnet_dt - (True, False), # type_one_side - ([], [[0, 1]]), # excluded_types - ("float32", "float64"), # precision - (0.0, 1e-8, 1e-2), # env_protection +SE_A_CASE_FIELDS = ( + "resnet_dt", + "type_one_side", + "excluded_types", + "precision", + "env_protection", ) + + +SE_A_BASELINE_CASE = { + "resnet_dt": True, + "type_one_side": True, + "excluded_types": [], + "precision": "float64", + "env_protection": 0.0, +} + + +def se_a_case(**overrides: Any) -> tuple: + case = SE_A_BASELINE_CASE | overrides + return tuple(case[field] for field in SE_A_CASE_FIELDS) + + +SE_A_CURATED_CASES = ( + # Baseline coverage. + se_a_case(), + # Core descriptor toggles. + se_a_case(resnet_dt=False), + se_a_case(type_one_side=False), + se_a_case(excluded_types=[[0, 1]]), + # Environment-protection edge cases. + se_a_case(env_protection=1e-8), + se_a_case(env_protection=1e-2), + # Lower-precision smoke coverage. + se_a_case(precision="float32"), +) + +SE_A_DESCRIPTOR_API_CASES = ( + # Descriptor API coverage keeps float64-only behavior from the original test. + se_a_case(), + se_a_case(resnet_dt=False), + se_a_case(type_one_side=False), + se_a_case(excluded_types=[[0, 1]]), + se_a_case(env_protection=1e-8), + se_a_case(env_protection=1e-2), +) + + +@parameterized_cases(*SE_A_CURATED_CASES) class TestSeA(CommonTest, DescriptorTest, unittest.TestCase): @property def data(self) -> dict: @@ -337,13 +379,17 @@ def atol(self) -> float: raise ValueError(f"Unknown precision: {precision}") -@parameterized( - (True,), # resnet_dt - (True,), # type_one_side - ([],), # excluded_types - ("float64",), # precision - (0.0, 1e-8, 1e-2), # env_protection +SE_A_STAT_CASES = ( + # Statistics path exercises the float64 baseline, a type_one_side=False + # variant, and env-protection variants. + se_a_case(), + se_a_case(type_one_side=False), + se_a_case(env_protection=1e-8), + se_a_case(env_protection=1e-2), ) + + +@parameterized_cases(*SE_A_STAT_CASES) class TestSeAStat(CommonTest, DescriptorTest, unittest.TestCase): @property def data(self) -> dict: @@ -669,13 +715,7 @@ def atol(self) -> float: raise ValueError(f"Unknown precision: {precision}") -@parameterized( - (True, False), # resnet_dt - (True, False), # type_one_side - ([], [[0, 1]]), # excluded_types - ("float64",), # precision - (0.0, 1e-8, 1e-2), # env_protection -) +@parameterized_cases(*SE_A_DESCRIPTOR_API_CASES) class TestSeADescriptorAPI(DescriptorAPITest, unittest.TestCase): """Test consistency of BaseDescriptor API methods across backends."""