diff --git a/source/tests/infer/fparam_aparam-testcase.yaml b/source/tests/infer/fparam_aparam-testcase.yaml index 220b2df209..1f300e31dd 100644 --- a/source/tests/infer/fparam_aparam-testcase.yaml +++ b/source/tests/infer/fparam_aparam-testcase.yaml @@ -26,7 +26,7 @@ model_def_script: "set_davg_zero": False, "trainable": True, "type": "se_e2_a", - "type_one_side": False, + "type_one_side": True, }, "fitting_net": { diff --git a/source/tests/infer/fparam_aparam.yaml b/source/tests/infer/fparam_aparam.yaml index e0654e142f..3f22bbcf6c 100644 --- a/source/tests/infer/fparam_aparam.yaml +++ b/source/tests/infer/fparam_aparam.yaml @@ -526,7 +526,7 @@ model: embeddings: "@class": NetworkCollection "@version": 1 - ndim: 2 + ndim: 1 network_type: embedding_network networks: - "@class": EmbeddingNetwork @@ -916,7 +916,7 @@ model: type: se_e2_a type_map: &id001 - O - type_one_side: false + type_one_side: true fitting: "@class": Fitting "@variables": @@ -2012,7 +2012,7 @@ model_def_script: set_davg_zero: false trainable: true type: se_e2_a - type_one_side: false + type_one_side: true fitting_net: activation_function: tanh atom_ener: *id004 diff --git a/source/tests/infer/fparam_aparam_default.yaml b/source/tests/infer/fparam_aparam_default.yaml index 6d64bfc328..5798817325 100644 --- a/source/tests/infer/fparam_aparam_default.yaml +++ b/source/tests/infer/fparam_aparam_default.yaml @@ -526,7 +526,7 @@ model: embeddings: "@class": NetworkCollection "@version": 1 - ndim: 2 + ndim: 1 network_type: embedding_network networks: - "@class": EmbeddingNetwork @@ -916,7 +916,7 @@ model: type: se_e2_a type_map: &id001 - O - type_one_side: false + type_one_side: true fitting: "@class": Fitting "@variables": @@ -2021,7 +2021,7 @@ model_def_script: set_davg_zero: false trainable: true type: se_e2_a - type_one_side: false + type_one_side: true fitting_net: activation_function: tanh atom_ener: *id004 diff --git a/source/tests/infer/test_models.py b/source/tests/infer/test_models.py index 7f7b7cc21c..44e3de30cb 100644 --- a/source/tests/infer/test_models.py +++ b/source/tests/infer/test_models.py @@ -13,6 +13,7 @@ ) from ..consistent.common import ( + INSTALLED_PT_EXPT, parameterized, ) from .case import ( @@ -28,7 +29,7 @@ "se_e2_r", "fparam_aparam", ), # key - (".pb", ".pth"), # model extension + (".pb", ".pth", ".pte", ".pt2"), # model extension ) class TestDeepPot(unittest.TestCase): # moved from tests/tf/test_deeppot_a.py @@ -36,21 +37,36 @@ class TestDeepPot(unittest.TestCase): @classmethod def setUpClass(cls) -> None: key, extension = cls.param + if extension in (".pte", ".pt2") and not INSTALLED_PT_EXPT: + raise unittest.SkipTest("pt_expt backend not installed") + if key in ("se_e2_a", "se_e2_r") and extension in (".pte", ".pt2"): + raise unittest.SkipTest( + "type_one_side=False is not supported for pt_expt export" + ) + if key == "se_e2_r" and extension == ".pth": + raise unittest.SkipTest( + "se_e2_r type_one_side is not supported for PyTorch models" + ) cls.case = get_cases()[key] - cls.model_name = cls.case.get_model(extension) + if extension == ".pt2": + import torch + + # Clear default device: tests/pt/__init__.py may set a fake + # device for CPU fallback, which poisons AOTInductor compilation. + saved_device = torch.get_default_device() + torch.set_default_device(None) + try: + cls.model_name = cls.case.get_model(extension) + finally: + torch.set_default_device(saved_device) + else: + cls.model_name = cls.case.get_model(extension) cls.dp = DeepEval(cls.model_name) @classmethod def tearDownClass(cls) -> None: cls.dp = None - def setUp(self) -> None: - key, extension = self.param - if key == "se_e2_r" and extension == ".pth": - self.skipTest( - reason="se_e2_r type_one_side is not supported for PyTorch models" - ) - def test_attrs(self) -> None: assert isinstance(self.dp, DeepPot) self.assertEqual(self.dp.get_ntypes(), self.case.ntypes) @@ -153,6 +169,8 @@ def test_1frame_atm(self) -> None: def test_descriptor(self) -> None: _, extension = self.param + if extension in (".pte", ".pt2"): + self.skipTest("eval_descriptor not supported for pt_expt models") for ii, result in enumerate(self.case.results): if result.descriptor is None: continue @@ -166,8 +184,8 @@ def test_descriptor(self) -> None: def test_fitting_last_layer(self) -> None: _, extension = self.param - if extension == ".pb": - self.skipTest("fitting_last_layer not supported for TensorFlow models") + if extension in (".pb", ".pte", ".pt2"): + self.skipTest("fitting_last_layer not supported for this backend") for ii, result in enumerate(self.case.results): if result.fit_ll is None: continue