Skip to content
72 changes: 61 additions & 11 deletions source/tests/consistent/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
import itertools
import os
import re
import sys
import unittest
from abc import (
Expand Down Expand Up @@ -75,7 +76,9 @@
"INSTALLED_PT_EXPT",
"INSTALLED_TF",
"CommonTest",
"CommonTest",
"parameterize_func",
"parameterized",
"parameterized_cases",
]

SKIP_FLAG = object()
Expand Down Expand Up @@ -670,6 +673,46 @@ def tearDown(self) -> None:
clear_session()


def _parameterized_with_cases(full_parameterized: list[tuple]) -> Callable:
def decorator(base_class: type):
class_module = sys.modules[base_class.__module__].__dict__
used_names: set[str] = set()
for pp in full_parameterized:

class TestClass(base_class):
param: ClassVar = pp

# generate a safe name for the class
parts = []
for x in pp:
s = str(x)
# replace non-alnum with underscore, collapse multiple underscores
s = re.sub(r"[^a-zA-Z0-9_]", "_", s)
s = re.sub(r"_+", "_", s)
# remove leading/trailing underscores
s = s.strip("_")
if s == "":
s = "empty"
parts.append(s)
base_name = f"{base_class.__name__}_{'_'.join(parts)}"
name = base_name
suffix = 1
while name in used_names or name in class_module:
name = f"{base_name}_{suffix}"
suffix += 1

TestClass.__name__ = name
TestClass.__qualname__ = name
TestClass.__module__ = base_class.__module__

used_names.add(name)
class_module[name] = TestClass
# make unittest module happy by ignoring the original one
return object

return decorator


def parameterized(*attrs: tuple, **subblock_attrs: tuple) -> Callable:
"""Parameterized test.

Expand Down Expand Up @@ -716,21 +759,28 @@ def parameterized(*attrs: tuple, **subblock_attrs: tuple) -> Callable:
else []
)
full_parameterized = global_combine + block_combine
return _parameterized_with_cases(full_parameterized)

def decorator(base_class: type):
class_module = sys.modules[base_class.__module__].__dict__
for pp in full_parameterized:

class TestClass(base_class):
param: ClassVar = pp
def parameterized_cases(*cases: tuple) -> Callable:
"""Parameterized test with explicit case tuples.

name = f"{base_class.__name__}_{'_'.join(str(x) for x in pp)}"
This variant behaves like :func:`parameterized` but takes a curated list of
case tuples directly instead of computing their Cartesian product.

class_module[name] = TestClass
# make unittest module happy by ignoring the original one
return object
Parameters
----------
*cases : tuple
Explicit case tuples.

return decorator
Returns
-------
object
The decorator.
"""
if not cases:
raise ValueError("parameterized_cases requires at least one case tuple")
return _parameterized_with_cases(list(cases))


def parameterize_func(
Expand Down
112 changes: 70 additions & 42 deletions source/tests/consistent/descriptor/test_dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
INSTALLED_PT_EXPT,
INSTALLED_TF,
CommonTest,
parameterized,
parameterized_cases,
)
from .common import (
DescriptorAPITest,
Expand Down Expand Up @@ -57,28 +57,76 @@
descrpt_se_atten_args,
)

DPA1_CASE_FIELDS = (
"tebd_dim",
"tebd_input_mode",
"resnet_dt",
"type_one_side",
"attn",
"attn_layer",
"attn_dotr",
"excluded_types",
"env_protection",
"set_davg_zero",
"scaling_factor",
"normalize",
"temperature",
"ln_eps",
"smooth_type_embedding",
"concat_output_tebd",
"precision",
"use_econf_tebd",
"use_tebd_bias",
)


@parameterized(
(4,), # tebd_dim
("concat", "strip"), # tebd_input_mode
(True,), # resnet_dt
(True,), # type_one_side
(20,), # attn
(0, 2), # attn_layer
(True,), # attn_dotr
([], [[0, 1]]), # excluded_types
(0.0,), # env_protection
(True, False), # set_davg_zero
(1.0,), # scaling_factor
(True,), # normalize
(None, 1.0), # temperature
(1e-5,), # ln_eps
(True,), # smooth_type_embedding
(True,), # concat_output_tebd
("float64",), # precision
(True, False), # use_econf_tebd
(False,), # use_tebd_bias
DPA1_BASELINE_CASE = {
"tebd_dim": 4,
"tebd_input_mode": "concat",
"resnet_dt": True,
"type_one_side": True,
"attn": 20,
"attn_layer": 2,
"attn_dotr": True,
"excluded_types": [],
"env_protection": 0.0,
"set_davg_zero": True,
"scaling_factor": 1.0,
"normalize": True,
"temperature": 1.0,
"ln_eps": 1e-5,
"smooth_type_embedding": True,
"concat_output_tebd": True,
"precision": "float64",
"use_econf_tebd": False,
"use_tebd_bias": False,
}


def dpa1_case(**overrides: Any) -> tuple:
case = DPA1_BASELINE_CASE | overrides
return tuple(case[field] for field in DPA1_CASE_FIELDS)


DPA1_CURATED_CASES = (
# Baseline coverage.
dpa1_case(),
# Alternate tebd input plumbing.
dpa1_case(tebd_input_mode="strip"),
# High-risk descriptor toggles.
dpa1_case(excluded_types=[[0, 1]]),
dpa1_case(set_davg_zero=False),
dpa1_case(normalize=False),
# Attention edge cases: disabled temperature path vs zero-layer path.
dpa1_case(temperature=None),
dpa1_case(attn_layer=0, temperature=None),
# econf-specific path with both tebd input modes.
dpa1_case(use_econf_tebd=True),
dpa1_case(tebd_input_mode="strip", use_econf_tebd=True),
)


@parameterized_cases(*DPA1_CURATED_CASES)
class TestDPA1(CommonTest, DescriptorTest, unittest.TestCase):
@property
def data(self) -> dict:
Expand Down Expand Up @@ -556,27 +604,7 @@ def atol(self) -> float:
raise ValueError(f"Unknown precision: {precision}")


@parameterized(
(4,), # tebd_dim
("concat", "strip"), # tebd_input_mode
(True,), # resnet_dt
(True,), # type_one_side
(20,), # attn
(0, 2), # attn_layer
(True,), # attn_dotr
([], [[0, 1]]), # excluded_types
(0.0,), # env_protection
(True, False), # set_davg_zero
(1.0,), # scaling_factor
(True,), # normalize
(None, 1.0), # temperature
(1e-5,), # ln_eps
(True,), # smooth_type_embedding
(True,), # concat_output_tebd
("float64",), # precision
(True, False), # use_econf_tebd
(False,), # use_tebd_bias
)
@parameterized_cases(*DPA1_CURATED_CASES)
class TestDPA1DescriptorAPI(DescriptorAPITest, unittest.TestCase):
"""Test consistency of BaseDescriptor API methods across backends."""

Expand Down
Loading
Loading