-
Notifications
You must be signed in to change notification settings - Fork 608
feat(pt/dpmodel): add sequential_update for dpa3 #5355
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -170,6 +170,12 @@ class RepFlowArgs: | |
| In the dynamic selection case, neighbor-scale normalization will use `e_sel / sel_reduce_factor` | ||
| or `a_sel / sel_reduce_factor` instead of the raw `e_sel` or `a_sel` values, | ||
| accommodating larger selection numbers. | ||
| sequential_update : bool, optional | ||
| Whether to use sequential update mode within each repflow layer. | ||
| When True, updates are applied sequentially: edge self → angle self (using updated edge) | ||
| → edge angle (using updated angle) → node (using final edge), | ||
| instead of the default parallel mode where all updates use original embeddings. | ||
| Currently only supports ``update_style='res_residual'``. | ||
| """ | ||
|
|
||
| def __init__( | ||
|
|
@@ -201,6 +207,7 @@ def __init__( | |
| use_exp_switch: bool = False, | ||
| use_dynamic_sel: bool = False, | ||
| sel_reduce_factor: float = 10.0, | ||
| sequential_update: bool = False, | ||
| ) -> None: | ||
| self.n_dim = n_dim | ||
| self.e_dim = e_dim | ||
|
|
@@ -231,6 +238,15 @@ def __init__( | |
| self.use_exp_switch = use_exp_switch | ||
| self.use_dynamic_sel = use_dynamic_sel | ||
| self.sel_reduce_factor = sel_reduce_factor | ||
| self.sequential_update = sequential_update | ||
| if self.sequential_update: | ||
| if self.update_style != "res_residual": | ||
| raise ValueError( | ||
| "sequential_update only supports update_style='res_residual', " | ||
| f"got '{self.update_style}'!" | ||
| ) | ||
| if not self.update_angle: | ||
| raise ValueError("sequential_update requires update_angle=True!") | ||
|
|
||
| def __getitem__(self, key: str) -> Any: | ||
| if hasattr(self, key): | ||
|
|
@@ -266,6 +282,7 @@ def serialize(self) -> dict: | |
| "use_exp_switch": self.use_exp_switch, | ||
| "use_dynamic_sel": self.use_dynamic_sel, | ||
| "sel_reduce_factor": self.sel_reduce_factor, | ||
| "sequential_update": self.sequential_update, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bump the DPA3 serialization version for the new Line 285 adds a new key to the persisted Suggested fix- "@version": 2,
+ "@version": 3,- check_version_compatibility(version, 2, 1)
+ check_version_compatibility(version, 3, 1)🤖 Prompt for AI Agents |
||
| } | ||
|
|
||
| @classmethod | ||
|
|
@@ -404,6 +421,7 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any: | |
| use_exp_switch=self.repflow_args.use_exp_switch, | ||
| use_dynamic_sel=self.repflow_args.use_dynamic_sel, | ||
| sel_reduce_factor=self.repflow_args.sel_reduce_factor, | ||
| sequential_update=self.repflow_args.sequential_update, | ||
| use_loc_mapping=use_loc_mapping, | ||
| exclude_types=exclude_types, | ||
| env_protection=env_protection, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
sequential_updateparameter docstring doesn’t mention that it requiresupdate_angle=True, butRepFlowArgs.__init__now raises a ValueError whensequential_updateis True andupdate_angleis False. Please reflect this constraint in the docstring to prevent confusing configuration errors.