diff --git a/deepmd/dpmodel/infer/deep_eval.py b/deepmd/dpmodel/infer/deep_eval.py
index ac6963b435..d567ece0e1 100644
--- a/deepmd/dpmodel/infer/deep_eval.py
+++ b/deepmd/dpmodel/infer/deep_eval.py
@@ -406,6 +406,10 @@ def get_model_def_script(self) -> dict:
         """Get model definition script."""
         return json.loads(self.dp.get_model_def_script())
 
+    def serialize(self) -> dict[str, Any]:
+        """Serialize the loaded model structure only."""
+        return self.dp.serialize()
+
     def get_observed_types(self) -> dict:
         """Get observed types (elements) of the model during data statistics.
 
diff --git a/deepmd/entrypoints/show.py b/deepmd/entrypoints/show.py
index 7fd3e81467..2cf1b881bb 100644
--- a/deepmd/entrypoints/show.py
+++ b/deepmd/entrypoints/show.py
@@ -4,6 +4,9 @@
     Any,
 )
 
+from deepmd.dpmodel.utils.serialization import (
+    Node,
+)
 from deepmd.infer.deep_eval import (
     DeepEval,
 )
@@ -136,3 +139,7 @@ def show(
         observed_types = model.get_observed_types()
         log.info(f"Number of observed types: {observed_types['type_num']} ")
         log.info(f"Observed types: {observed_types['observed_type']} ")
+
+    if "serialization-tree" in ATTRIBUTES:
+        tree = Node.deserialize(model.serialize())
+        log.info(f"Model serialization tree:\n{tree}")
diff --git a/deepmd/infer/deep_eval.py b/deepmd/infer/deep_eval.py
index d375a2ecd7..ae6720349e 100644
--- a/deepmd/infer/deep_eval.py
+++ b/deepmd/infer/deep_eval.py
@@ -361,6 +361,16 @@ def get_model(self) -> Any:
             The model module implemented by the deep learning framework.
         """
 
+    @abstractmethod
+    def serialize(self) -> dict[str, Any]:
+        """Serialize the loaded model structure only.
+
+        Returns
+        -------
+        dict
+            Serialized model tree that can be consumed by ``Node.deserialize``.
+        """
+
 
 class DeepEval(ABC):
     """High-level Deep Evaluator interface.
 
@@ -404,6 +414,7 @@ def __init__(
         neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
         **kwargs: Any,
     ) -> None:
+        self.model_file = model_file
         self.deep_eval = DeepEvalBackend(
             model_file,
             self.output_def,
@@ -420,6 +431,10 @@ def __init__(
     def output_def(self) -> ModelOutputDef:
         """Returns the output variable definitions."""
 
+    def serialize(self) -> dict[str, Any]:
+        """Serialize the loaded model structure only."""
+        return self.deep_eval.serialize()
+
     def get_rcut(self) -> float:
         """Get the cutoff radius of this model."""
         return self.deep_eval.get_rcut()
diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py
index 2e028225f7..c2784c3b5f 100644
--- a/deepmd/jax/infer/deep_eval.py
+++ b/deepmd/jax/infer/deep_eval.py
@@ -187,6 +187,17 @@ def get_ntypes_spin(self) -> int:
         """Get the number of spin atom types of this model."""
         return 0
 
+    def serialize(self) -> dict[str, Any]:
+        """Serialize the loaded model structure only."""
+        from deepmd.jax.utils.serialization import (
+            serialize_from_file,
+        )
+
+        data = serialize_from_file(self.model_path)
+        if "model" in data:
+            return data["model"]
+        raise RuntimeError("Serialized model data does not contain key 'model'.")
+
     def eval(
         self,
         coords: np.ndarray,
diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py
index 14386d9f3d..ba6ad31ea0 100644
--- a/deepmd/jax/utils/serialization.py
+++ b/deepmd/jax/utils/serialization.py
@@ -202,5 +202,15 @@ def convert_str_to_int_key(item: dict) -> None:
         data.pop("constants")
         data["@variables"].pop("stablehlo")
         return data
+    elif model_file.endswith(".savedmodel"):
+        # NOTE(review): the path is handed to the TF serializer unchanged;
+        # confirm it can read a SavedModel written by the JAX backend.
+        from deepmd.tf.utils.serialization import (
+            serialize_from_file as serialize_savedmodel,
+        )
+
+        return serialize_savedmodel(model_file)
     else:
-        raise ValueError("JAX backend only supports converting .jax directory")
+        raise ValueError(
+            "JAX backend only supports converting .jax directory, .hlo, and .savedmodel"
+        )
diff --git a/deepmd/main.py b/deepmd/main.py
index 3afcda8b4a..6ef2fcd0d2 100644
--- a/deepmd/main.py
+++ b/deepmd/main.py
@@ -942,6 +942,7 @@ def main_parser() -> argparse.ArgumentParser:
             "fitting-net",
             "size",
             "observed-type",
+            "serialization-tree",
         ],
         nargs="+",
     )
diff --git a/deepmd/pd/infer/deep_eval.py b/deepmd/pd/infer/deep_eval.py
index 6c0ffed7ec..561faca562 100644
--- a/deepmd/pd/infer/deep_eval.py
+++ b/deepmd/pd/infer/deep_eval.py
@@ -730,6 +730,12 @@ def get_model_def_script(self) -> dict:
         """Get model definition script."""
         return self.model_def_script
 
+    def serialize(self) -> dict[str, Any]:
+        """Serialize the loaded model structure only."""
+        if isinstance(self.dp, ModelWrapper):
+            return self.dp.model["Default"].serialize()
+        return self.dp.serialize()
+
     def get_model_size(self) -> dict:
         """Get model parameter count.
 
diff --git a/deepmd/pretrained/deep_eval.py b/deepmd/pretrained/deep_eval.py
index 2dc671b0cc..aa15a50760 100644
--- a/deepmd/pretrained/deep_eval.py
+++ b/deepmd/pretrained/deep_eval.py
@@ -184,3 +184,7 @@ def get_ntypes_spin(self) -> int:
 
     def get_model(self) -> Any:
         return self._backend.get_model()
+
+    def serialize(self) -> dict[str, Any]:
+        """Serialize the loaded model structure only."""
+        return self._backend.serialize()
diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py
index 11a877040d..1e4c8a0929 100644
--- a/deepmd/pt/infer/deep_eval.py
+++ b/deepmd/pt/infer/deep_eval.py
@@ -702,6 +702,10 @@ def get_model_def_script(self) -> dict:
         """Get model definition script."""
         return self.model_def_script
 
+    def serialize(self) -> dict[str, Any]:
+        """Serialize the loaded model structure only."""
+        return self.dp.model["Default"].serialize()
+
     def get_model_size(self) -> dict:
         """Get model parameter count.
 
diff --git a/deepmd/pt_expt/infer/deep_eval.py b/deepmd/pt_expt/infer/deep_eval.py
index a6e1e1e540..992d43e8cd 100644
--- a/deepmd/pt_expt/infer/deep_eval.py
+++ b/deepmd/pt_expt/infer/deep_eval.py
@@ -665,6 +665,17 @@ def get_model_def_script(self) -> dict:
         """Get model definition script."""
         return self.metadata
 
+    def serialize(self) -> dict[str, Any]:
+        """Serialize the loaded model structure only."""
+        from deepmd.pt_expt.utils.serialization import (
+            serialize_from_file,
+        )
+
+        data = serialize_from_file(self.model_path)
+        if "model" not in data:
+            raise RuntimeError("Serialized model data does not contain key 'model'.")
+        return data["model"]
+
     def get_model(self) -> torch.nn.Module:
         """Get the exported model module.
 
diff --git a/deepmd/tf/infer/deep_eval.py b/deepmd/tf/infer/deep_eval.py
index 0ec2f1c74e..0d81ce588f 100644
--- a/deepmd/tf/infer/deep_eval.py
+++ b/deepmd/tf/infer/deep_eval.py
@@ -110,6 +110,7 @@ def __init__(
             input_map=input_map,
         )
         self.load_prefix = load_prefix
+        self.model_file = model_file
 
         # graph_compatable should be called after graph and prefix are set
         if not self._graph_compatable():
@@ -1121,6 +1122,23 @@ def get_model_def_script(self) -> dict:
         model_def_script = script.decode("utf-8")
         return json.loads(model_def_script)["model"]
 
+    def serialize(self) -> dict[str, Any]:
+        """Serialize the loaded model structure only."""
+        from deepmd.tf.model.model import (
+            Model,
+        )
+        from deepmd.tf.utils.graph import (
+            load_graph_def,
+        )
+
+        graph, graph_def = load_graph_def(str(self.model_file))
+
+        model_def_script = self.get_model_def_script()
+        model = Model(**model_def_script)
+        # important! must be called before serialize
+        model.init_variables(graph=graph, graph_def=graph_def)
+        return model.serialize()
+
     def get_model(self) -> "tf.Graph":
         """Get the TensorFlow graph.
 
@@ -1172,6 +1190,7 @@ def __init__(
             input_map=input_map,
         )
         self.load_prefix = load_prefix
+        self.model_file = model_file
 
         # graph_compatable should be called after graph and prefix are set
         if not self._graph_compatable():
diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py
index 982d56d8fa..c492e52510 100644
--- a/source/tests/consistent/io/test_io.py
+++ b/source/tests/consistent/io/test_io.py
@@ -156,11 +156,12 @@ def test_deep_eval(self) -> None:
             if not backend.is_available():
                 continue
             reference_data = copy.deepcopy(self.data)
-            self.save_data_to_model(
-                prefix + backend.suffixes[suffix_idx], reference_data
-            )
-            deep_eval = DeepEval(prefix + backend.suffixes[suffix_idx])
+            model_file = prefix + backend.suffixes[suffix_idx]
+            self.save_data_to_model(model_file, reference_data)
+            deep_eval = DeepEval(model_file)
             self.assertIsInstance(deep_eval.get_model_def_script(), dict)
+            serialized_data = self.get_data_from_model(model_file)
+            np.testing.assert_equal(deep_eval.serialize(), serialized_data["model"])
             if deep_eval.get_dim_fparam() > 0:
                 fparam = np.ones((nframes, deep_eval.get_dim_fparam()))
             else:
diff --git a/source/tests/pt_expt/infer/test_deep_eval.py b/source/tests/pt_expt/infer/test_deep_eval.py
index ef38e1d36f..26d0741f7b 100644
--- a/source/tests/pt_expt/infer/test_deep_eval.py
+++ b/source/tests/pt_expt/infer/test_deep_eval.py
@@ -109,6 +109,13 @@ def test_get_model_def_script(self) -> None:
         self.assertAlmostEqual(mds["rcut"], self.rcut)
         self.assertEqual(mds["sel"], list(self.sel))
 
+    def test_serialize_returns_model_tree(self) -> None:
+        data = self.dp.deep_eval.serialize()
+        reference = self.model.serialize()
+        self.assertEqual(data["@class"], reference["@class"])
+        self.assertEqual(data["type"], reference["type"])
+        np.testing.assert_equal(data, serialize_from_file(self.tmpfile.name)["model"])
+
     def test_eval_consistency(self) -> None:
         """Test that DeepPot.eval gives same results as direct model forward."""
         rng = np.random.default_rng(GLOBAL_SEED)
diff --git a/source/tests/test_entrypoint_show_serialization_tree.py b/source/tests/test_entrypoint_show_serialization_tree.py
new file mode 100644
index 0000000000..9f01651115
--- /dev/null
+++ b/source/tests/test_entrypoint_show_serialization_tree.py
@@ -0,0 +1,30 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import unittest
+from unittest.mock import (
+    patch,
+)
+
+from deepmd.entrypoints.show import (
+    show,
+)
+
+
+class TestShowSerializationTree(unittest.TestCase):
+    def test_serialization_tree_uses_deep_eval_model_payload(self) -> None:
+        """Check that ``show`` logs the tree built from ``serialize()``."""
+        with (
+            patch("deepmd.entrypoints.show.DeepEval") as mock_deep_eval,
+            patch("deepmd.entrypoints.show.Node.deserialize") as mock_deserialize,
+            patch("deepmd.entrypoints.show.log.info") as mock_log_info,
+        ):
+            model = mock_deep_eval.return_value
+            model.get_model_def_script.return_value = {"type_map": ["H", "O"]}
+            model.get_model_size.return_value = {}
+            model.serialize.return_value = {"@class": "MockModel"}
+            mock_deserialize.return_value = "ROOT"
+
+            show(INPUT="mock.pte", ATTRIBUTES=["serialization-tree"])
+
+        model.serialize.assert_called_once_with()
+        mock_deserialize.assert_called_once_with({"@class": "MockModel"})
+        mock_log_info.assert_any_call("Model serialization tree:\nROOT")