diff --git a/deepmd/dpmodel/descriptor/dpa3.py b/deepmd/dpmodel/descriptor/dpa3.py index 5f5aea50e5..88f56213ee 100644 --- a/deepmd/dpmodel/descriptor/dpa3.py +++ b/deepmd/dpmodel/descriptor/dpa3.py @@ -170,6 +170,12 @@ class RepFlowArgs: In the dynamic selection case, neighbor-scale normalization will use `e_sel / sel_reduce_factor` or `a_sel / sel_reduce_factor` instead of the raw `e_sel` or `a_sel` values, accommodating larger selection numbers. + sequential_update : bool, optional + Whether to use sequential update mode within each repflow layer. + When True, updates are applied sequentially: edge self → angle self (using updated edge) + → edge angle (using updated angle) → node (using final edge), + instead of the default parallel mode where all updates use original embeddings. + Currently only supports ``update_style='res_residual'``. """ def __init__( @@ -201,6 +207,7 @@ def __init__( use_exp_switch: bool = False, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, + sequential_update: bool = False, ) -> None: self.n_dim = n_dim self.e_dim = e_dim @@ -231,6 +238,15 @@ def __init__( self.use_exp_switch = use_exp_switch self.use_dynamic_sel = use_dynamic_sel self.sel_reduce_factor = sel_reduce_factor + self.sequential_update = sequential_update + if self.sequential_update: + if self.update_style != "res_residual": + raise ValueError( + "sequential_update only supports update_style='res_residual', " + f"got '{self.update_style}'!" + ) + if not self.update_angle: + raise ValueError("sequential_update requires update_angle=True!") def __getitem__(self, key: str) -> Any: if hasattr(self, key): @@ -266,6 +282,7 @@ def serialize(self) -> dict: "use_exp_switch": self.use_exp_switch, "use_dynamic_sel": self.use_dynamic_sel, "sel_reduce_factor": self.sel_reduce_factor, + "sequential_update": self.sequential_update, } @classmethod @@ -404,6 +421,7 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any: use_exp_switch=self.repflow_args.use_exp_switch, use_dynamic_sel=self.repflow_args.use_dynamic_sel, sel_reduce_factor=self.repflow_args.sel_reduce_factor, + sequential_update=self.repflow_args.sequential_update, use_loc_mapping=use_loc_mapping, exclude_types=exclude_types, env_protection=env_protection, diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index 30637dc75a..5788840279 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -230,6 +230,7 @@ def __init__( use_exp_switch: bool = False, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, + sequential_update: bool = False, use_loc_mapping: bool = True, seed: int | list[int] | None = None, trainable: bool = True, @@ -268,6 +269,7 @@ def __init__( self.use_dynamic_sel = use_dynamic_sel self.use_loc_mapping = use_loc_mapping self.sel_reduce_factor = sel_reduce_factor + self.sequential_update = sequential_update if self.use_dynamic_sel and not self.smooth_edge_update: raise NotImplementedError( "smooth_edge_update must be True when use_dynamic_sel is True!" @@ -339,6 +341,7 @@ def __init__( optim_update=self.optim_update, use_dynamic_sel=self.use_dynamic_sel, sel_reduce_factor=self.sel_reduce_factor, + sequential_update=self.sequential_update, smooth_edge_update=self.smooth_edge_update, seed=child_seed(child_seed(seed, 1), ii), trainable=trainable, @@ -757,6 +760,7 @@ def serialize(self) -> dict: "smooth_edge_update": self.smooth_edge_update, "use_dynamic_sel": self.use_dynamic_sel, "sel_reduce_factor": self.sel_reduce_factor, + "sequential_update": self.sequential_update, "use_loc_mapping": self.use_loc_mapping, # variables "edge_embd": self.edge_embd.serialize(), @@ -905,6 +909,7 @@ def __init__( optim_update: bool = True, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, + sequential_update: bool = False, smooth_edge_update: bool = False, activation_function: str = "silu", update_style: str = "res_residual", @@ -954,8 +959,15 @@ def __init__( self.smooth_edge_update = smooth_edge_update self.use_dynamic_sel = use_dynamic_sel self.sel_reduce_factor = sel_reduce_factor + self.sequential_update = sequential_update self.dynamic_e_sel = self.nnei / self.sel_reduce_factor self.dynamic_a_sel = self.a_sel / self.sel_reduce_factor + if self.sequential_update and self.update_style != "res_residual": + raise NotImplementedError( + "sequential_update only supports update_style='res_residual'!" + ) + if self.sequential_update and not self.update_angle: + raise NotImplementedError("sequential_update requires update_angle=True!") assert update_residual_init in [ "norm", @@ -1342,6 +1354,418 @@ def optim_edge_update_dynamic( result_update = bias + sub_node_update + sub_edge_update + sub_node_ext_update return result_update + def _call_sequential( + self, + xp: object, + node_ebd: Array, + node_ebd_ext: Array, + edge_ebd: Array, + h2: Array, + angle_ebd: Array, + nlist: Array, + nlist_mask: Array, + sw: Array, + a_nlist_mask: Array, + a_sw: Array, + nei_node_ebd: Array, + n2e_index: Array, + n_ext2e_index: Array, + n2a_index: Array, + eij2a_index: Array, + eik2a_index: Array, + nb: int, + nloc: int, + nnei: int, + nall: int, + n_edge: int, + ) -> tuple[Array, Array, Array]: + """Sequential update path: edge_self → angle_self → edge_angle → node. + + Only supports update_style='res_residual'. + """ + assert self.angle_self_linear is not None + assert self.edge_angle_linear1 is not None + assert self.edge_angle_linear2 is not None + + # ==================================================================== + # Phase 1: Edge self update (uses original node_ebd, edge_ebd) + # ==================================================================== + if not self.optim_update: + if not self.use_dynamic_sel: + edge_info = xp.concat( + [ + xp.tile( + xp.reshape(node_ebd, (nb, nloc, 1, self.n_dim)), + (1, 1, self.nnei, 1), + ), + nei_node_ebd, + edge_ebd, + ], + axis=-1, + ) + else: + edge_info = xp.concat( + [ + xp.take( + xp.reshape(node_ebd, (-1, self.n_dim)), + n2e_index, + axis=0, + ), + nei_node_ebd, + edge_ebd, + ], + axis=-1, + ) + edge_self_update = self.act(self.edge_self_linear(edge_info)) + else: + edge_self_update = self.act( + self.optim_edge_update( + node_ebd, + node_ebd_ext, + edge_ebd, + nlist, + "edge", + ) + if not self.use_dynamic_sel + else self.optim_edge_update_dynamic( + node_ebd, + node_ebd_ext, + edge_ebd, + n2e_index, + n_ext2e_index, + "edge", + ) + ) + + # Apply edge self residual + edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update + + # ==================================================================== + # Phase 2: Angle self update (uses original node_ebd, updated edge_ebd_s1) + # ==================================================================== + if self.a_compress_rate != 0: + if not self.a_compress_use_split: + assert self.a_compress_n_linear is not None + assert self.a_compress_e_linear is not None + node_ebd_for_angle = self.a_compress_n_linear(node_ebd) + edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd_s1) + else: + node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] + edge_ebd_for_angle = edge_ebd_s1[..., : self.e_a_compress_dim] + else: + node_ebd_for_angle = node_ebd + edge_ebd_for_angle = edge_ebd_s1 + + if not self.use_dynamic_sel: + edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] + edge_ebd_for_angle = xp.where( + xp.expand_dims(a_nlist_mask, axis=-1), + edge_ebd_for_angle, + xp.zeros_like(edge_ebd_for_angle), + ) + + if not self.optim_update: + node_for_angle_info = ( + xp.tile( + xp.reshape( + node_ebd_for_angle, (nb, nloc, 1, 1, self.n_a_compress_dim) + ), + (1, 1, self.a_sel, self.a_sel, 1), + ) + if not self.use_dynamic_sel + else xp.take( + xp.reshape(node_ebd_for_angle, (-1, self.n_a_compress_dim)), + n2a_index, + axis=0, + ) + ) + edge_for_angle_k = ( + xp.tile( + xp.reshape( + edge_ebd_for_angle, + (nb, nloc, 1, self.a_sel, self.e_a_compress_dim), + ), + (1, 1, self.a_sel, 1, 1), + ) + if not self.use_dynamic_sel + else xp.take( + edge_ebd_for_angle, + eik2a_index, + axis=0, + ) + ) + edge_for_angle_j = ( + xp.tile( + xp.reshape( + edge_ebd_for_angle, + (nb, nloc, self.a_sel, 1, self.e_a_compress_dim), + ), + (1, 1, 1, self.a_sel, 1), + ) + if not self.use_dynamic_sel + else xp.take( + edge_ebd_for_angle, + eij2a_index, + axis=0, + ) + ) + edge_for_angle_info = xp.concat( + [edge_for_angle_k, edge_for_angle_j], axis=-1 + ) + angle_info = xp.concat( + [angle_ebd, node_for_angle_info, edge_for_angle_info], axis=-1 + ) + angle_self_update = self.act(self.angle_self_linear(angle_info)) + else: + angle_self_update = self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "angle", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "angle", + ) + ) + + # Apply angle self residual + a_updated = angle_ebd + self.a_residual[0] * angle_self_update + + # ==================================================================== + # Phase 3: Edge angle update (uses updated angle a_updated, updated edge_ebd_s1) + # ==================================================================== + if not self.optim_update: + angle_info_s2 = xp.concat( + [a_updated, node_for_angle_info, edge_for_angle_info], axis=-1 + ) + edge_angle_update = self.act(self.edge_angle_linear1(angle_info_s2)) + else: + edge_angle_update = self.act( + self.optim_angle_update( + a_updated, + node_ebd_for_angle, + edge_ebd_for_angle, + "edge", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + a_updated, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "edge", + ) + ) + + # Reduce edge angle update over angle dimension + if not self.use_dynamic_sel: + weighted_edge_angle_update = ( + a_sw[:, :, :, xp.newaxis, xp.newaxis] + * a_sw[:, :, xp.newaxis, :, xp.newaxis] + * edge_angle_update + ) + reduced_edge_angle_update = xp.sum(weighted_edge_angle_update, axis=-2) / ( + self.a_sel**0.5 + ) + padding_edge_angle_update = xp.concat( + [ + reduced_edge_angle_update, + xp.zeros( + (nb, nloc, self.nnei - self.a_sel, self.e_dim), + dtype=edge_ebd.dtype, + device=array_api_compat.device(edge_ebd), + ), + ], + axis=2, + ) + else: + weighted_edge_angle_update = edge_angle_update * xp.expand_dims( + a_sw, axis=-1 + ) + padding_edge_angle_update = aggregate( + weighted_edge_angle_update, + eij2a_index, + average=False, + num_owner=n_edge, + ) / (self.dynamic_a_sel**0.5) + + if not self.smooth_edge_update: + if self.use_dynamic_sel: + raise NotImplementedError( + "smooth_edge_update must be True when use_dynamic_sel is True!" + ) + full_mask = xp.concat( + [ + a_nlist_mask, + xp.zeros( + (nb, nloc, self.nnei - self.a_sel), + dtype=a_nlist_mask.dtype, + device=array_api_compat.device(a_nlist_mask), + ), + ], + axis=-1, + ) + padding_edge_angle_update = xp.where( + xp.expand_dims(full_mask, axis=-1), + padding_edge_angle_update, + edge_ebd, + ) + + edge_angle_processed = self.act( + self.edge_angle_linear2(padding_edge_angle_update) + ) + + # Apply edge angle residual on top of edge_ebd_s1 + e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_processed + + # ==================================================================== + # Phase 4: Node edge message (uses e_updated) + # ==================================================================== + if not self.optim_update: + if not self.use_dynamic_sel: + edge_info_updated = xp.concat( + [ + xp.tile( + xp.reshape(node_ebd, (nb, nloc, 1, self.n_dim)), + (1, 1, self.nnei, 1), + ), + nei_node_ebd, + e_updated, + ], + axis=-1, + ) + else: + edge_info_updated = xp.concat( + [ + xp.take( + xp.reshape(node_ebd, (-1, self.n_dim)), + n2e_index, + axis=0, + ), + nei_node_ebd, + e_updated, + ], + axis=-1, + ) + node_edge_update = self.act( + self.node_edge_linear(edge_info_updated) + ) * xp.expand_dims(sw, axis=-1) + else: + node_edge_update = self.act( + self.optim_edge_update( + node_ebd, + node_ebd_ext, + e_updated, + nlist, + "node", + ) + if not self.use_dynamic_sel + else self.optim_edge_update_dynamic( + node_ebd, + node_ebd_ext, + e_updated, + n2e_index, + n_ext2e_index, + "node", + ) + ) * xp.expand_dims(sw, axis=-1) + + node_edge_update = ( + (xp.sum(node_edge_update, axis=-2) / self.nnei) + if not self.use_dynamic_sel + else ( + xp.reshape( + aggregate( + node_edge_update, + n2e_index, + average=False, + num_owner=nb * nloc, + ), + (nb, nloc, node_edge_update.shape[-1]), + ) + / self.dynamic_e_sel + ) + ) + + # ==================================================================== + # Phase 5: Node updates (node_self, node_sym with e_updated, node_edge) + # ==================================================================== + n_update_list: list[Array] = [node_ebd] + + # node self mlp + node_self_mlp = self.act(self.node_self_mlp(node_ebd)) + n_update_list.append(node_self_mlp) + + # node sym using e_updated + node_sym_list: list[Array] = [] + node_sym_list.append( + symmetrization_op( + e_updated, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + if not self.use_dynamic_sel + else symmetrization_op_dynamic( + e_updated, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) + ) + node_sym_list.append( + symmetrization_op( + nei_node_ebd, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + if not self.use_dynamic_sel + else symmetrization_op_dynamic( + nei_node_ebd, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) + ) + node_sym = self.act(self.node_sym_linear(xp.concat(node_sym_list, axis=-1))) + n_update_list.append(node_sym) + + if self.n_multi_edge_message > 1: + node_edge_update_mul_head = xp.reshape( + node_edge_update, (nb, nloc, self.n_multi_edge_message, self.n_dim) + ) + for head_index in range(self.n_multi_edge_message): + n_update_list.append(node_edge_update_mul_head[:, :, head_index, :]) + else: + n_update_list.append(node_edge_update) + + n_updated = self.list_update(n_update_list, "node") + + return n_updated, e_updated, a_updated + def call( self, node_ebd_ext: Array, # nf x nall x n_dim @@ -1446,6 +1870,32 @@ def call( ) ) + if self.sequential_update and self.update_angle: + return self._call_sequential( + xp, + node_ebd, + node_ebd_ext, + edge_ebd, + h2, + angle_ebd, + nlist, + nlist_mask, + sw, + a_nlist_mask, + a_sw, + nei_node_ebd, + n2e_index, + n_ext2e_index, + n2a_index, + eij2a_index, + eik2a_index, + nb, + nloc, + nnei, + nall, + n_edge, + ) + n_update_list: list[Array] = [node_ebd] e_update_list: list[Array] = [edge_ebd] a_update_list: list[Array] = [angle_ebd] @@ -1907,6 +2357,7 @@ def serialize(self) -> dict: "smooth_edge_update": self.smooth_edge_update, "use_dynamic_sel": self.use_dynamic_sel, "sel_reduce_factor": self.sel_reduce_factor, + "sequential_update": self.sequential_update, "node_self_mlp": self.node_self_mlp.serialize(), "node_sym_linear": self.node_sym_linear.serialize(), "node_edge_linear": self.node_edge_linear.serialize(), diff --git a/deepmd/pt/model/descriptor/dpa3.py b/deepmd/pt/model/descriptor/dpa3.py index 0c6982afe5..a5f79280fa 100644 --- a/deepmd/pt/model/descriptor/dpa3.py +++ b/deepmd/pt/model/descriptor/dpa3.py @@ -167,6 +167,7 @@ def init_subclass_params(sub_data: Any, sub_class: Any) -> Any: use_exp_switch=self.repflow_args.use_exp_switch, use_dynamic_sel=self.repflow_args.use_dynamic_sel, sel_reduce_factor=self.repflow_args.sel_reduce_factor, + sequential_update=self.repflow_args.sequential_update, use_loc_mapping=use_loc_mapping, exclude_types=exclude_types, env_protection=env_protection, diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index 338f48b060..57c9368839 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -53,6 +53,7 @@ def __init__( optim_update: bool = True, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, + sequential_update: bool = False, smooth_edge_update: bool = False, activation_function: str = "silu", update_style: str = "res_residual", @@ -102,8 +103,15 @@ def __init__( self.smooth_edge_update = smooth_edge_update self.use_dynamic_sel = use_dynamic_sel self.sel_reduce_factor = sel_reduce_factor + self.sequential_update = sequential_update self.dynamic_e_sel = self.nnei / self.sel_reduce_factor self.dynamic_a_sel = self.a_sel / self.sel_reduce_factor + if self.sequential_update and self.update_style != "res_residual": + raise NotImplementedError( + "sequential_update only supports update_style='res_residual'!" + ) + if self.sequential_update and not self.update_angle: + raise NotImplementedError("sequential_update requires update_angle=True!") assert update_residual_init in [ "norm", @@ -694,6 +702,383 @@ def optim_edge_update_dynamic( result_update = bias + sub_node_update + sub_edge_update + sub_node_ext_update return result_update + def _forward_sequential( + self, + node_ebd: torch.Tensor, + node_ebd_ext: torch.Tensor, + edge_ebd: torch.Tensor, + h2: torch.Tensor, + angle_ebd: torch.Tensor, + nlist: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + a_nlist_mask: torch.Tensor, + a_sw: torch.Tensor, + nei_node_ebd: torch.Tensor, + n2e_index: torch.Tensor, + n_ext2e_index: torch.Tensor, + n2a_index: torch.Tensor, + eij2a_index: torch.Tensor, + eik2a_index: torch.Tensor, + nb: int, + nloc: int, + nnei: int, + nall: int, + n_edge: int | None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Sequential update path: edge_self → angle_self → edge_angle → node. + + Only supports update_style='res_residual'. + """ + assert self.edge_angle_linear1 is not None + assert self.edge_angle_linear2 is not None + assert self.angle_self_linear is not None + + # ==================================================================== + # Phase 1: Edge self update (uses original node_ebd, edge_ebd) + # ==================================================================== + if not self.optim_update: + if not self.use_dynamic_sel: + edge_info = torch.cat( + [ + torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), + nei_node_ebd, + edge_ebd, + ], + dim=-1, + ) + else: + edge_info = torch.cat( + [ + torch.index_select( + node_ebd.reshape(-1, self.n_dim), 0, n2e_index + ), + nei_node_ebd, + edge_ebd, + ], + dim=-1, + ) + edge_self_update = self.act(self.edge_self_linear(edge_info)) + else: + edge_self_update = self.act( + self.optim_edge_update( + node_ebd, + node_ebd_ext, + edge_ebd, + nlist, + "edge", + ) + if not self.use_dynamic_sel + else self.optim_edge_update_dynamic( + node_ebd, + node_ebd_ext, + edge_ebd, + n2e_index, + n_ext2e_index, + "edge", + ) + ) + + # Apply edge self residual: edge_ebd_s1 = edge_ebd + e_residual[0] * edge_self_update + edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update + + # ==================================================================== + # Phase 2: Angle self update (uses original node_ebd, updated edge_ebd_s1) + # ==================================================================== + # Prepare edge for angle from edge_ebd_s1 (updated edge) + if self.a_compress_rate != 0: + if not self.a_compress_use_split: + assert self.a_compress_n_linear is not None + assert self.a_compress_e_linear is not None + node_ebd_for_angle = self.a_compress_n_linear(node_ebd) + edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd_s1) + else: + node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] + edge_ebd_for_angle = edge_ebd_s1[..., : self.e_a_compress_dim] + else: + node_ebd_for_angle = node_ebd + edge_ebd_for_angle = edge_ebd_s1 + + if not self.use_dynamic_sel: + edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] + edge_ebd_for_angle = torch.where( + a_nlist_mask.unsqueeze(-1), edge_ebd_for_angle, 0.0 + ) + + # Initialize for JIT: these are only used in non-optim_update path + node_for_angle_info = angle_ebd # placeholder, overwritten below + edge_for_angle_info = angle_ebd # placeholder, overwritten below + + if not self.optim_update: + node_for_angle_info = ( + torch.tile( + node_ebd_for_angle.unsqueeze(2).unsqueeze(2), + (1, 1, self.a_sel, self.a_sel, 1), + ) + if not self.use_dynamic_sel + else torch.index_select( + node_ebd_for_angle.reshape(-1, self.n_a_compress_dim), + 0, + n2a_index, + ) + ) + edge_for_angle_k = ( + torch.tile(edge_ebd_for_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1)) + if not self.use_dynamic_sel + else torch.index_select(edge_ebd_for_angle, 0, eik2a_index) + ) + edge_for_angle_j = ( + torch.tile(edge_ebd_for_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1)) + if not self.use_dynamic_sel + else torch.index_select(edge_ebd_for_angle, 0, eij2a_index) + ) + edge_for_angle_info = torch.cat( + [edge_for_angle_k, edge_for_angle_j], dim=-1 + ) + angle_info = torch.cat( + [angle_ebd, node_for_angle_info, edge_for_angle_info], dim=-1 + ) + angle_self_update = self.act(self.angle_self_linear(angle_info)) + else: + angle_self_update = self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "angle", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "angle", + ) + ) + + # Apply angle self residual: angle_ebd_s2 = angle_ebd + a_residual[0] * angle_self_update + a_updated = angle_ebd + self.a_residual[0] * angle_self_update + + # ==================================================================== + # Phase 3: Edge angle update (uses updated angle_ebd_s2, updated edge_ebd_s1) + # ==================================================================== + if not self.optim_update: + # Rebuild angle_info with updated angle (a_updated) + angle_info_s2 = torch.cat( + [a_updated, node_for_angle_info, edge_for_angle_info], dim=-1 + ) + edge_angle_update = self.act(self.edge_angle_linear1(angle_info_s2)) + else: + edge_angle_update = self.act( + self.optim_angle_update( + a_updated, + node_ebd_for_angle, + edge_ebd_for_angle, + "edge", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + a_updated, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "edge", + ) + ) + + # Reduce edge angle update over angle dimension + if not self.use_dynamic_sel: + weighted_edge_angle_update = ( + a_sw.unsqueeze(-1).unsqueeze(-1) + * a_sw.unsqueeze(-2).unsqueeze(-1) + * edge_angle_update + ) + reduced_edge_angle_update = torch.sum( + weighted_edge_angle_update, dim=-2 + ) / (self.a_sel**0.5) + padding_edge_angle_update = torch.concat( + [ + reduced_edge_angle_update, + torch.zeros( + [nb, nloc, self.nnei - self.a_sel, self.e_dim], + dtype=edge_ebd.dtype, + device=edge_ebd.device, + ), + ], + dim=2, + ) + else: + assert n_edge is not None + weighted_edge_angle_update = edge_angle_update * a_sw.unsqueeze(-1) + padding_edge_angle_update = aggregate( + weighted_edge_angle_update, + eij2a_index, + average=False, + num_owner=n_edge, + ) / (self.dynamic_a_sel**0.5) + + if not self.smooth_edge_update: + if self.use_dynamic_sel: + raise NotImplementedError( + "smooth_edge_update must be True when use_dynamic_sel is True!" + ) + full_mask = torch.concat( + [ + a_nlist_mask, + torch.zeros( + [nb, nloc, self.nnei - self.a_sel], + dtype=a_nlist_mask.dtype, + device=a_nlist_mask.device, + ), + ], + dim=-1, + ) + padding_edge_angle_update = torch.where( + full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd + ) + + edge_angle_processed = self.act( + self.edge_angle_linear2(padding_edge_angle_update) + ) + + # Apply edge angle residual on top of edge_ebd_s1 (no recomputation) + e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_processed + + # ==================================================================== + # Phase 4: Node edge message (uses e_updated) + # ==================================================================== + if not self.optim_update: + if not self.use_dynamic_sel: + edge_info_updated = torch.cat( + [ + torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), + nei_node_ebd, + e_updated, + ], + dim=-1, + ) + else: + edge_info_updated = torch.cat( + [ + torch.index_select( + node_ebd.reshape(-1, self.n_dim), 0, n2e_index + ), + nei_node_ebd, + e_updated, + ], + dim=-1, + ) + node_edge_update = self.act( + self.node_edge_linear(edge_info_updated) + ) * sw.unsqueeze(-1) + else: + node_edge_update = self.act( + self.optim_edge_update( + node_ebd, + node_ebd_ext, + e_updated, + nlist, + "node", + ) + if not self.use_dynamic_sel + else self.optim_edge_update_dynamic( + node_ebd, + node_ebd_ext, + e_updated, + n2e_index, + n_ext2e_index, + "node", + ) + ) * sw.unsqueeze(-1) + + node_edge_update = ( + (torch.sum(node_edge_update, dim=-2) / self.nnei) + if not self.use_dynamic_sel + else ( + aggregate( + node_edge_update, + n2e_index, + average=False, + num_owner=nb * nloc, + ).reshape(nb, nloc, node_edge_update.shape[-1]) + / self.dynamic_e_sel + ) + ) + + # ==================================================================== + # Phase 5: Node updates (node_self, node_sym with e_updated, node_edge) + # ==================================================================== + n_update_list: list[torch.Tensor] = [node_ebd] + + # node self mlp (uses original node_ebd) + node_self_mlp = self.act(self.node_self_mlp(node_ebd)) + n_update_list.append(node_self_mlp) + + # node sym using e_updated + node_sym_list: list[torch.Tensor] = [] + node_sym_list.append( + self.symmetrization_op( + e_updated, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + if not self.use_dynamic_sel + else self.symmetrization_op_dynamic( + e_updated, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) + ) + node_sym_list.append( + self.symmetrization_op( + nei_node_ebd, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + if not self.use_dynamic_sel + else self.symmetrization_op_dynamic( + nei_node_ebd, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) + ) + node_sym = self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1))) + n_update_list.append(node_sym) + + if self.n_multi_edge_message > 1: + node_edge_update_mul_head = node_edge_update.view( + nb, nloc, self.n_multi_edge_message, self.n_dim + ) + for head_index in range(self.n_multi_edge_message): + n_update_list.append(node_edge_update_mul_head[..., head_index, :]) + else: + n_update_list.append(node_edge_update) + + n_updated = self.list_update(n_update_list, "node") + + return n_updated, e_updated, a_updated + def forward( self, node_ebd_ext: torch.Tensor, # nf x nall x n_dim [OR] nf x nloc x n_dim when not parallel_mode @@ -783,6 +1168,31 @@ def forward( ) ) + if self.sequential_update and self.update_angle: + return self._forward_sequential( + node_ebd, + node_ebd_ext, + edge_ebd, + h2, + angle_ebd, + nlist, + nlist_mask, + sw, + a_nlist_mask, + a_sw, + nei_node_ebd, + n2e_index, + n_ext2e_index, + n2a_index, + eij2a_index, + eik2a_index, + nb, + nloc, + nnei, + nall, + n_edge, + ) + n_update_list: list[torch.Tensor] = [node_ebd] e_update_list: list[torch.Tensor] = [edge_ebd] a_update_list: list[torch.Tensor] = [angle_ebd] @@ -1220,6 +1630,7 @@ def serialize(self) -> dict: "smooth_edge_update": self.smooth_edge_update, "use_dynamic_sel": self.use_dynamic_sel, "sel_reduce_factor": self.sel_reduce_factor, + "sequential_update": self.sequential_update, "node_self_mlp": self.node_self_mlp.serialize(), "node_sym_linear": self.node_sym_linear.serialize(), "node_edge_linear": self.node_edge_linear.serialize(), diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py index 433897860f..7c16ab3c7a 100644 --- a/deepmd/pt/model/descriptor/repflows.py +++ b/deepmd/pt/model/descriptor/repflows.py @@ -219,6 +219,7 @@ def __init__( use_exp_switch: bool = False, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, + sequential_update: bool = False, use_loc_mapping: bool = True, optim_update: bool = True, seed: int | list[int] | None = None, @@ -258,6 +259,7 @@ def __init__( self.use_exp_switch = use_exp_switch self.use_dynamic_sel = use_dynamic_sel self.sel_reduce_factor = sel_reduce_factor + self.sequential_update = sequential_update if self.use_dynamic_sel and not self.smooth_edge_update: raise NotImplementedError( "smooth_edge_update must be True when use_dynamic_sel is True!" @@ -329,6 +331,7 @@ def __init__( optim_update=self.optim_update, use_dynamic_sel=self.use_dynamic_sel, sel_reduce_factor=self.sel_reduce_factor, + sequential_update=self.sequential_update, smooth_edge_update=self.smooth_edge_update, seed=child_seed(child_seed(seed, 1), ii), trainable=trainable, diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index b12bc7ef6f..70a7985702 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1550,6 +1550,13 @@ def dpa3_repflow_args() -> list[Argument]: "or `a_sel / sel_reduce_factor` instead of the raw `e_sel` or `a_sel` values, " "accommodating larger selection numbers." ) + doc_sequential_update = ( + "Whether to use sequential update mode within each repflow layer. " + "When True, updates are applied sequentially: edge self → angle self (using updated edge) " + "→ edge angle (using updated angle) → node (using final edge), " + "instead of the default parallel mode where all updates use original embeddings. " + "Currently only supports update_style='res_residual'." + ) return [ # repflow args @@ -1680,6 +1687,13 @@ def dpa3_repflow_args() -> list[Argument]: default=10.0, doc=doc_sel_reduce_factor, ), + Argument( + "sequential_update", + bool, + optional=True, + default=False, + doc=doc_sequential_update, + ), ] diff --git a/source/tests/consistent/descriptor/test_dpa3.py b/source/tests/consistent/descriptor/test_dpa3.py index bca0759f5c..b980c584a1 100644 --- a/source/tests/consistent/descriptor/test_dpa3.py +++ b/source/tests/consistent/descriptor/test_dpa3.py @@ -79,6 +79,7 @@ (1,), # n_multi_edge_message ("float64",), # precision (False, True), # add_chg_spin_ebd + (False, True), # sequential_update ) class TestDPA3(CommonTest, DescriptorTest, unittest.TestCase): @property @@ -99,6 +100,7 @@ def data(self) -> dict: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param return { "ntypes": self.ntypes, @@ -130,6 +132,7 @@ def data(self) -> dict: "update_style": "res_residual", "update_residual": 0.1, "update_residual_init": update_residual_init, + "sequential_update": sequential_update, } ), # kwargs for descriptor @@ -160,6 +163,7 @@ def skip_pt(self) -> bool: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param return CommonTest.skip_pt @@ -181,6 +185,7 @@ def skip_pd(self) -> bool: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param return True if add_chg_spin_ebd else CommonTest.skip_pd @@ -202,6 +207,7 @@ def skip_dp(self) -> bool: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param return CommonTest.skip_dp @@ -223,6 +229,7 @@ def skip_tf(self) -> bool: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param return True @@ -288,6 +295,7 @@ def setUp(self) -> None: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param # fparam for charge=5, spin=1 when add_chg_spin_ebd is True self.fparam = ( @@ -394,6 +402,7 @@ def rtol(self) -> float: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param if precision == "float64": return 1e-10 @@ -421,6 +430,7 @@ def atol(self) -> float: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param if precision == "float64": return 1e-6 # need to fix in the future, see issue https://github.com/deepmodeling/deepmd-kit/issues/3786 diff --git a/source/tests/pt/model/test_dpa3.py b/source/tests/pt/model/test_dpa3.py index 12b0be4532..d66eab9dea 100644 --- a/source/tests/pt/model/test_dpa3.py +++ b/source/tests/pt/model/test_dpa3.py @@ -56,6 +56,7 @@ def test_consistency( prec, ect, add_chg_spin, + seq_upd, ) in itertools.product( [True, False], # update_angle ["res_residual"], # update_style @@ -67,7 +68,11 @@ def test_consistency( ["float64"], # precision [False], # use_econf_tebd [False, True], # add_chg_spin_ebd + [False, True], # sequential_update ): + # sequential_update only works with update_angle=True + if seq_upd and not ua: + continue dtype = PRECISION_DICT[prec] rtol, atol = get_tols(prec) if prec == "float64": @@ -93,6 +98,7 @@ def test_consistency( update_style=rus, update_residual_init=ruri, smooth_edge_update=True, + sequential_update=seq_upd, ) # dpa3 new impl @@ -177,6 +183,7 @@ def test_jit( nme, prec, ect, + seq_upd, ) in itertools.product( [True], # update_angle ["res_residual"], # update_style @@ -187,6 +194,7 @@ def test_jit( [1, 2], # n_multi_edge_message ["float64"], # precision [False], # use_econf_tebd + [False, True], # sequential_update ): dtype = PRECISION_DICT[prec] rtol, atol = get_tols(prec) @@ -211,6 +219,7 @@ def test_jit( update_style=rus, update_residual_init=ruri, smooth_edge_update=True, + sequential_update=seq_upd, ) # dpa3 new impl