Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,10 @@ def get_model_def_script(self) -> dict:
"""Get model definition script."""
return json.loads(self.dp.get_model_def_script())

def serialize(self) -> dict[str, Any]:
model = self.dp
return model.serialize()

def get_observed_types(self) -> dict:
"""Get observed types (elements) of the model during data statistics.

Expand Down
7 changes: 7 additions & 0 deletions deepmd/entrypoints/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
Any,
)

from deepmd.dpmodel.utils.serialization import (
Node,
)
from deepmd.infer.deep_eval import (
DeepEval,
)
Expand Down Expand Up @@ -136,3 +139,7 @@ def show(
observed_types = model.get_observed_types()
log.info(f"Number of observed types: {observed_types['type_num']} ")
log.info(f"Observed types: {observed_types['observed_type']} ")

if "serialization-tree" in ATTRIBUTES:
root = Node.deserialize(model.serialize())
log.info("Model serialization tree:\n" + str(root))
15 changes: 15 additions & 0 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,16 @@ def get_model(self) -> Any:
The model module implemented by the deep learning framework.
"""

@abstractmethod
def serialize(self) -> dict[str, Any]:
"""Serialize the loaded model structure only.

Returns
-------
dict
Serialized model tree that can be consumed by ``Node.deserialize``.
"""


class DeepEval(ABC):
"""High-level Deep Evaluator interface.
Expand Down Expand Up @@ -404,6 +414,7 @@ def __init__(
neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
**kwargs: Any,
) -> None:
self.model_file = model_file
self.deep_eval = DeepEvalBackend(
model_file,
self.output_def,
Expand All @@ -420,6 +431,10 @@ def __init__(
def output_def(self) -> ModelOutputDef:
"""Returns the output variable definitions."""

def serialize(self) -> dict[str, Any]:
"""Serialize the loaded model structure only."""
return self.deep_eval.serialize()

def get_rcut(self) -> float:
"""Get the cutoff radius of this model."""
return self.deep_eval.get_rcut()
Expand Down
10 changes: 10 additions & 0 deletions deepmd/jax/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,16 @@ def get_ntypes_spin(self) -> int:
"""Get the number of spin atom types of this model."""
return 0

def serialize(self) -> dict[str, Any]:
from deepmd.jax.utils.serialization import (
serialize_from_file,
)

data = serialize_from_file(self.model_path)
if "model" not in data:
raise RuntimeError("Serialized model data does not contain key 'model'.")
return data["model"]

def eval(
self,
coords: np.ndarray,
Expand Down
1 change: 1 addition & 0 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,7 @@ def main_parser() -> argparse.ArgumentParser:
"fitting-net",
"size",
"observed-type",
"serialization-tree",
],
nargs="+",
)
Expand Down
6 changes: 6 additions & 0 deletions deepmd/pd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,12 @@ def get_model_def_script(self) -> dict:
"""Get model definition script."""
return self.model_def_script

def serialize(self) -> dict[str, Any]:
model = (
self.dp.model["Default"] if isinstance(self.dp, ModelWrapper) else self.dp
)
return model.serialize()

def get_model_size(self) -> dict:
"""Get model parameter count.

Expand Down
3 changes: 3 additions & 0 deletions deepmd/pretrained/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,6 @@ def get_ntypes_spin(self) -> int:

def get_model(self) -> Any:
return self._backend.get_model()

def serialize(self) -> dict[str, Any]:
return self._backend.serialize()
4 changes: 4 additions & 0 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,10 @@ def get_model_def_script(self) -> dict:
"""Get model definition script."""
return self.model_def_script

def serialize(self) -> dict[str, Any]:
model = self.dp.model["Default"]
return model.serialize()

def get_model_size(self) -> dict:
"""Get model parameter count.

Expand Down
7 changes: 7 additions & 0 deletions deepmd/pt_expt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,13 @@ def get_model_def_script(self) -> dict:
"""Get model definition script."""
return self.metadata

def serialize(self) -> dict[str, Any]:
from deepmd.pt_expt.utils.serialization import (
serialize_from_file,
)

return serialize_from_file(self.model_path)

def get_model(self) -> torch.nn.Module:
"""Get the exported model module.

Expand Down
18 changes: 18 additions & 0 deletions deepmd/tf/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def __init__(
input_map=input_map,
)
self.load_prefix = load_prefix
self.model_file = model_file

# graph_compatable should be called after graph and prefix are set
if not self._graph_compatable():
Expand Down Expand Up @@ -1121,6 +1122,22 @@ def get_model_def_script(self) -> dict:
model_def_script = script.decode("utf-8")
return json.loads(model_def_script)["model"]

def serialize(self) -> dict[str, Any]:
from deepmd.tf.model.model import (
Model,
)
from deepmd.tf.utils.graph import (
load_graph_def,
)

graph, graph_def = load_graph_def(str(self.model_file))

model_def_script = self.get_model_def_script()
model = Model(**model_def_script)
# important! must be called before serialize
model.init_variables(graph=graph, graph_def=graph_def)
return model.serialize()

def get_model(self) -> "tf.Graph":
"""Get the TensorFlow graph.

Expand Down Expand Up @@ -1172,6 +1189,7 @@ def __init__(
input_map=input_map,
)
self.load_prefix = load_prefix
self.model_file = model_file

# graph_compatable should be called after graph and prefix are set
if not self._graph_compatable():
Expand Down
9 changes: 5 additions & 4 deletions source/tests/consistent/io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,12 @@ def test_deep_eval(self) -> None:
if not backend.is_available():
continue
reference_data = copy.deepcopy(self.data)
self.save_data_to_model(
prefix + backend.suffixes[suffix_idx], reference_data
)
deep_eval = DeepEval(prefix + backend.suffixes[suffix_idx])
model_file = prefix + backend.suffixes[suffix_idx]
self.save_data_to_model(model_file, reference_data)
deep_eval = DeepEval(model_file)
self.assertIsInstance(deep_eval.get_model_def_script(), dict)
serialized_data = self.get_data_from_model(model_file)
np.testing.assert_equal(deep_eval.serialize(), serialized_data["model"])
if deep_eval.get_dim_fparam() > 0:
fparam = np.ones((nframes, deep_eval.get_dim_fparam()))
else:
Expand Down
6 changes: 6 additions & 0 deletions source/tests/pt_expt/infer/test_deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ def test_get_model_def_script(self) -> None:
self.assertAlmostEqual(mds["rcut"], self.rcut)
self.assertEqual(mds["sel"], list(self.sel))

def test_serialize_returns_model_tree(self) -> None:
data = self.dp.deep_eval.serialize()
self.assertEqual(data["@class"], self.model.serialize()["@class"])
self.assertEqual(data["type"], self.model.serialize()["type"])
self.assertEqual(data, serialize_from_file(self.tmpfile.name))

def test_eval_consistency(self) -> None:
"""Test that DeepPot.eval gives same results as direct model forward."""
rng = np.random.default_rng(GLOBAL_SEED)
Expand Down
29 changes: 29 additions & 0 deletions source/tests/test_entrypoint_show_serialization_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest
from unittest.mock import (
patch,
)

from deepmd.entrypoints.show import (
show,
)


class TestShowSerializationTree(unittest.TestCase):
def test_serialization_tree_uses_deep_eval_model_payload(self) -> None:
with (
patch("deepmd.entrypoints.show.DeepEval") as mock_deep_eval,
patch("deepmd.entrypoints.show.Node.deserialize") as mock_deserialize,
patch("deepmd.entrypoints.show.log.info") as mock_log_info,
):
model = mock_deep_eval.return_value
model.get_model_def_script.return_value = {"type_map": ["H", "O"]}
model.get_model_size.return_value = {}
model.serialize.return_value = {"@class": "MockModel"}
mock_deserialize.return_value = "ROOT"

show(INPUT="mock.pte", ATTRIBUTES=["serialization-tree"])

model.serialize.assert_called_once_with()
mock_deserialize.assert_called_once_with({"@class": "MockModel"})
mock_log_info.assert_any_call("Model serialization tree:\nROOT")
Loading