Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion source/tests/infer/fparam_aparam-testcase.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ model_def_script:
"set_davg_zero": False,
"trainable": True,
"type": "se_e2_a",
"type_one_side": False,
"type_one_side": True,
},
"fitting_net":
{
Expand Down
6 changes: 3 additions & 3 deletions source/tests/infer/fparam_aparam.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ model:
embeddings:
"@class": NetworkCollection
"@version": 1
ndim: 2
ndim: 1
network_type: embedding_network
networks:
- "@class": EmbeddingNetwork
Expand Down Expand Up @@ -916,7 +916,7 @@ model:
type: se_e2_a
type_map: &id001
- O
type_one_side: false
type_one_side: true
fitting:
"@class": Fitting
"@variables":
Expand Down Expand Up @@ -2012,7 +2012,7 @@ model_def_script:
set_davg_zero: false
trainable: true
type: se_e2_a
type_one_side: false
type_one_side: true
fitting_net:
activation_function: tanh
atom_ener: *id004
Expand Down
6 changes: 3 additions & 3 deletions source/tests/infer/fparam_aparam_default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ model:
embeddings:
"@class": NetworkCollection
"@version": 1
ndim: 2
ndim: 1
network_type: embedding_network
networks:
- "@class": EmbeddingNetwork
Expand Down Expand Up @@ -916,7 +916,7 @@ model:
type: se_e2_a
type_map: &id001
- O
type_one_side: false
type_one_side: true
fitting:
"@class": Fitting
"@variables":
Expand Down Expand Up @@ -2021,7 +2021,7 @@ model_def_script:
set_davg_zero: false
trainable: true
type: se_e2_a
type_one_side: false
type_one_side: true
fitting_net:
activation_function: tanh
atom_ener: *id004
Expand Down
40 changes: 29 additions & 11 deletions source/tests/infer/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)

from ..consistent.common import (
INSTALLED_PT_EXPT,
parameterized,
)
from .case import (
Expand All @@ -28,29 +29,44 @@
"se_e2_r",
"fparam_aparam",
), # key
(".pb", ".pth"), # model extension
(".pb", ".pth", ".pte", ".pt2"), # model extension
)
class TestDeepPot(unittest.TestCase):
# moved from tests/tf/test_deeppot_a.py

@classmethod
def setUpClass(cls) -> None:
key, extension = cls.param
if extension in (".pte", ".pt2") and not INSTALLED_PT_EXPT:
raise unittest.SkipTest("pt_expt backend not installed")
if key in ("se_e2_a", "se_e2_r") and extension in (".pte", ".pt2"):
raise unittest.SkipTest(
"type_one_side=False is not supported for pt_expt export"
)
if key == "se_e2_r" and extension == ".pth":
raise unittest.SkipTest(
"se_e2_r type_one_side is not supported for PyTorch models"
)
cls.case = get_cases()[key]
cls.model_name = cls.case.get_model(extension)
if extension == ".pt2":
import torch

# Clear default device: tests/pt/__init__.py may set a fake
# device for CPU fallback, which poisons AOTInductor compilation.
saved_device = torch.get_default_device()
torch.set_default_device(None)
try:
cls.model_name = cls.case.get_model(extension)
finally:
torch.set_default_device(saved_device)
else:
cls.model_name = cls.case.get_model(extension)
cls.dp = DeepEval(cls.model_name)

@classmethod
def tearDownClass(cls) -> None:
cls.dp = None

def setUp(self) -> None:
key, extension = self.param
if key == "se_e2_r" and extension == ".pth":
self.skipTest(
reason="se_e2_r type_one_side is not supported for PyTorch models"
)

def test_attrs(self) -> None:
assert isinstance(self.dp, DeepPot)
self.assertEqual(self.dp.get_ntypes(), self.case.ntypes)
Expand Down Expand Up @@ -153,6 +169,8 @@ def test_1frame_atm(self) -> None:

def test_descriptor(self) -> None:
_, extension = self.param
if extension in (".pte", ".pt2"):
self.skipTest("eval_descriptor not supported for pt_expt models")
for ii, result in enumerate(self.case.results):
if result.descriptor is None:
continue
Expand All @@ -166,8 +184,8 @@ def test_descriptor(self) -> None:

def test_fitting_last_layer(self) -> None:
_, extension = self.param
if extension == ".pb":
self.skipTest("fitting_last_layer not supported for TensorFlow models")
if extension in (".pb", ".pte", ".pt2"):
self.skipTest("fitting_last_layer not supported for this backend")
for ii, result in enumerate(self.case.results):
if result.fit_ll is None:
continue
Expand Down
Loading