Feature/group offload pinning #12747
base: main
Changes from all commits
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,7 @@ | |
| from contextlib import contextmanager, nullcontext | ||
| from dataclasses import dataclass, replace | ||
| from enum import Enum | ||
| from typing import Dict, List, Optional, Set, Tuple, Union | ||
| from typing import Callable, Dict, List, Optional, Set, Tuple, Union | ||
|
|
||
| import safetensors.torch | ||
| import torch | ||
|
|
@@ -27,6 +27,9 @@ | |
| from .hooks import HookRegistry, ModelHook | ||
|
|
||
|
|
||
| VALID_PIN_GROUPS = {"all", "first_last"} | ||
|
|
||
|
|
||
| if is_accelerate_available(): | ||
| from accelerate.hooks import AlignDevicesHook, CpuOffload | ||
| from accelerate.utils import send_to_device | ||
|
|
@@ -62,6 +65,7 @@ class GroupOffloadingConfig: | |
| block_modules: Optional[List[str]] = None | ||
| exclude_kwargs: Optional[List[str]] = None | ||
| module_prefix: Optional[str] = "" | ||
| pin_groups: Optional[Union[str, Callable]] = None | ||
|
|
||
|
|
||
| class ModuleGroup: | ||
|
|
@@ -94,6 +98,7 @@ def __init__( | |
| self.record_stream = record_stream | ||
| self.onload_self = onload_self | ||
| self.low_cpu_mem_usage = low_cpu_mem_usage | ||
| self.pinned = False | ||
|
|
||
| self.offload_to_disk_path = offload_to_disk_path | ||
| self._is_offloaded_to_disk = False | ||
|
|
@@ -156,7 +161,7 @@ def _pinned_memory_tensors(self): | |
| finally: | ||
| pinned_dict = None | ||
|
|
||
| def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream): | ||
| def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream=None): | ||
| tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking) | ||
| if self.record_stream: | ||
| tensor.data.record_stream(default_stream) | ||
|
|
@@ -212,7 +217,6 @@ def _onload_from_memory(self): | |
|
|
||
| context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream) | ||
| default_stream = self._torch_accelerator_module.current_stream() if self.stream is not None else None | ||
|
|
||
| with context: | ||
| if self.stream is not None: | ||
| with self._pinned_memory_tensors() as pinned_memory: | ||
|
|
@@ -291,7 +295,8 @@ def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None | |
| self.config = config | ||
|
|
||
| def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: | ||
| if self.group.offload_leader == module: | ||
| # For disk offload we materialize the safetensor files upfront so callers can inspect them immediately. | ||
|
Member
This doesn't seem related to the PR.
Author
I did not understand. We dropped the eager offload so adapters can load before anything leaves the device; the only eager path left is when `offload_to_disk_path` is set.
||
| if self.group.offload_to_disk_path is not None and self.group.offload_leader == module: | ||
| self.group.offload_() | ||
| return module | ||
|
|
||
|
|
@@ -300,35 +305,48 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs): | |
| # method is the onload_leader of the group. | ||
| if self.group.onload_leader is None: | ||
| self.group.onload_leader = module | ||
| is_leader = self.group.onload_leader == module | ||
| should_onload_next_group = self.next_group is not None and not self.next_group.onload_self | ||
| should_orchestrate = self.group.pinned or is_leader | ||
|
|
||
| if should_orchestrate: | ||
| # Pinned groups keep their params on the onload device; orchestrate onload/prefetch/sync every call. | ||
| if self.group.pinned: | ||
| if is_leader and not self._is_group_on_device(): | ||
| self.group.onload_() | ||
| else: | ||
| if is_leader and self.group.onload_self: | ||
| self.group.onload_() | ||
|
|
||
| # If the current module is the onload_leader of the group, we onload the group if it is supposed | ||
| # to onload itself. In the case of using prefetching with streams, we onload the next group if | ||
| # it is not supposed to onload itself. | ||
| if self.group.onload_leader == module: | ||
| if self.group.onload_self: | ||
| self.group.onload_() | ||
|
|
||
| should_onload_next_group = self.next_group is not None and not self.next_group.onload_self | ||
| if should_onload_next_group: | ||
| self.next_group.onload_() | ||
|
|
||
| should_synchronize = ( | ||
| not self.group.onload_self and self.group.stream is not None and not should_onload_next_group | ||
| not self.group.onload_self | ||
| and self.group.stream is not None | ||
| and not should_onload_next_group | ||
| and not self.group.record_stream | ||
| ) | ||
| if should_synchronize: | ||
| # If this group didn't onload itself, it means it was asynchronously onloaded by the | ||
| # previous group. We need to synchronize the side stream to ensure parameters | ||
| # are completely loaded to proceed with forward pass. Without this, uninitialized | ||
| # weights will be used in the computation, leading to incorrect results | ||
| # Also, we should only do this synchronization if we don't already do it from the sync call in | ||
| # self.next_group.onload_, hence the `not should_onload_next_group` check. | ||
| # weights will be used in the computation, leading to incorrect results. | ||
| self.group.stream.synchronize() | ||
|
|
||
| args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking) | ||
| kwargs = self._send_kwargs_to_device(kwargs) | ||
| return args, kwargs | ||
|
|
||
| def post_forward(self, module: torch.nn.Module, output): | ||
| if self.group.pinned: | ||
| return output | ||
|
|
||
| if self.group.offload_leader == module: | ||
| self.group.offload_() | ||
| return output | ||
|
|
||
| # Some Autoencoder models use a feature cache that is passed through submodules | ||
| # and modified in place. The `send_to_device` call returns a copy of this feature cache object | ||
| # which breaks the inplace updates. Use `exclude_kwargs` to mark these cache features | ||
| def _send_kwargs_to_device(self, kwargs): | ||
| exclude_kwargs = self.config.exclude_kwargs or [] | ||
| if exclude_kwargs: | ||
| moved_kwargs = send_to_device( | ||
|
|
@@ -337,15 +355,21 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs): | |
| non_blocking=self.group.non_blocking, | ||
| ) | ||
| kwargs.update(moved_kwargs) | ||
| else: | ||
| kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking) | ||
| return kwargs | ||
| return send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking) | ||
|
|
||
| return args, kwargs | ||
| def _is_group_on_device(self) -> bool: | ||
|
Collaborator
We have removed the duplicate method definitions for `_is_group_on_device`.
||
| tensors = [] | ||
| for group_module in self.group.modules: | ||
| tensors.extend(list(group_module.parameters())) | ||
| tensors.extend(list(group_module.buffers())) | ||
| tensors.extend(self.group.parameters) | ||
| tensors.extend(self.group.buffers) | ||
|
|
||
| def post_forward(self, module: torch.nn.Module, output): | ||
| if self.group.offload_leader == module: | ||
| self.group.offload_() | ||
| return output | ||
| if len(tensors) == 0: | ||
| return True | ||
|
|
||
| return all(t.device == self.group.onload_device for t in tensors) | ||
|
|
||
|
|
||
| class LazyPrefetchGroupOffloadingHook(ModelHook): | ||
|
|
@@ -358,9 +382,10 @@ class LazyPrefetchGroupOffloadingHook(ModelHook): | |
|
|
||
| _is_stateful = False | ||
|
|
||
| def __init__(self): | ||
| def __init__(self, pin_groups: Optional[Union[str, Callable]] = None): | ||
| self.execution_order: List[Tuple[str, torch.nn.Module]] = [] | ||
| self._layer_execution_tracker_module_names = set() | ||
| self.pin_groups = pin_groups | ||
|
|
||
| def initialize_hook(self, module): | ||
| def make_execution_order_update_callback(current_name, current_submodule): | ||
|
|
@@ -442,6 +467,50 @@ def post_forward(self, module, output): | |
| group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group | ||
| group_offloading_hooks[i].next_group.onload_self = False | ||
|
|
||
| if self.pin_groups is not None and num_executed > 0: | ||
| param_exec_info = [] | ||
| for idx, ((name, submodule), hook) in enumerate(zip(self.execution_order, group_offloading_hooks)): | ||
| if hook is None: | ||
| continue | ||
| if next(submodule.parameters(), None) is None and next(submodule.buffers(), None) is None: | ||
| continue | ||
| param_exec_info.append((name, submodule, hook)) | ||
|
|
||
| num_param_modules = len(param_exec_info) | ||
| if num_param_modules > 0: | ||
| pinned_indices = set() | ||
| if isinstance(self.pin_groups, str): | ||
| if self.pin_groups == "all": | ||
| pinned_indices = set(range(num_param_modules)) | ||
| elif self.pin_groups == "first_last": | ||
| pinned_indices.add(0) | ||
| pinned_indices.add(num_param_modules - 1) | ||
| elif callable(self.pin_groups): | ||
| for idx, (name, submodule, _) in enumerate(param_exec_info): | ||
| should_pin = False | ||
| try: | ||
| should_pin = bool(self.pin_groups(submodule)) | ||
| except TypeError: | ||
| try: | ||
| should_pin = bool(self.pin_groups(name, submodule)) | ||
| except TypeError: | ||
| should_pin = bool(self.pin_groups(name, submodule, idx)) | ||
| if should_pin: | ||
| pinned_indices.add(idx) | ||
|
|
||
| pinned_groups = set() | ||
| for idx in pinned_indices: | ||
| if idx >= num_param_modules: | ||
| continue | ||
| group = param_exec_info[idx][2].group | ||
| if group not in pinned_groups: | ||
| group.pinned = True | ||
| pinned_groups.add(group) | ||
|
|
||
| for group in pinned_groups: | ||
| if group.offload_device != group.onload_device: | ||
| group.onload_() | ||
|
|
||
| return output | ||
|
|
||
|
|
||
|
|
@@ -461,6 +530,16 @@ def pre_forward(self, module, *args, **kwargs): | |
| return args, kwargs | ||
|
|
||
|
|
||
| def _validate_pin_groups(pin_groups: Optional[Union[str, Callable]]) -> Optional[Union[str, Callable]]: | ||
| if pin_groups is None or callable(pin_groups): | ||
| return pin_groups | ||
| if isinstance(pin_groups, str) and pin_groups in VALID_PIN_GROUPS: | ||
| return pin_groups | ||
| raise ValueError( | ||
| f"`pin_groups` must be None, {', '.join(repr(v) for v in sorted(VALID_PIN_GROUPS))}, or a callable." | ||
| ) | ||
|
|
||
|
|
||
| def apply_group_offloading( | ||
| module: torch.nn.Module, | ||
| onload_device: Union[str, torch.device], | ||
|
|
@@ -474,6 +553,7 @@ def apply_group_offloading( | |
| offload_to_disk_path: Optional[str] = None, | ||
| block_modules: Optional[List[str]] = None, | ||
| exclude_kwargs: Optional[List[str]] = None, | ||
| pin_groups: Optional[Union[str, Callable]] = None, | ||
| ) -> None: | ||
| r""" | ||
| Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and | ||
|
|
@@ -535,9 +615,13 @@ def apply_group_offloading( | |
| List of module names that should be treated as blocks for offloading. If provided, only these modules will | ||
| be considered for block-level offloading. If not provided, the default block detection logic will be used. | ||
| exclude_kwargs (`List[str]`, *optional*): | ||
| List of kwarg keys that should not be processed by send_to_device. This is useful for mutable state like | ||
| List of kwarg keys that should not be processed by `send_to_device`. This is useful for mutable state like | ||
| caching lists that need to maintain their object identity across forward passes. If not provided, will be | ||
| inferred from the module's `_skip_keys` attribute if it exists. | ||
| pin_groups (`"first_last"` or `"all"` or `Callable`, *optional*, defaults to `None`): | ||
| Optionally keeps selected groups on the onload device permanently. Use `"first_last"` to pin the first and | ||
| last parameter-bearing groups, `"all"` to pin every parameter-bearing group, or pass a callable that | ||
| receives a module (and optionally the module name and index) and returns `True` to pin that group. | ||
|
|
||
| Example: | ||
| ```python | ||
|
|
@@ -577,6 +661,7 @@ def apply_group_offloading( | |
| if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None: | ||
| raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.") | ||
|
|
||
| pin_groups = _validate_pin_groups(pin_groups) | ||
| _raise_error_if_accelerate_model_or_sequential_hook_present(module) | ||
|
|
||
| if block_modules is None: | ||
|
|
@@ -597,6 +682,8 @@ def apply_group_offloading( | |
| offload_to_disk_path=offload_to_disk_path, | ||
| block_modules=block_modules, | ||
| exclude_kwargs=exclude_kwargs, | ||
| module_prefix="", | ||
|
Member
Why do we have to default to this?
Author
It is just an internal label to avoid name collisions when we recurse into explicit `block_modules`. At the top level we leave it empty so the ids stay the same; when we recurse into a child we prefix its name so that two submodules with the same class name don't clash on group id. It is not exposed to users.
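For illustration, a tiny sketch of the collision the prefix avoids when two children share the same block class. The id format below (dot separator, class name plus block index) is an assumption for readability; only the `f"{config.module_prefix}{...}"` pattern comes from the diff.

```python
# Illustrative only: group ids built with and without a module prefix.
def make_group_id(module_prefix: str, cls_name: str, index: int) -> str:
    # mirrors the f"{config.module_prefix}{name}" pattern from the diff
    return f"{module_prefix}{cls_name}_block_{index}"

# Top level: empty prefix, so existing ids (and disk-offload filenames) stay unchanged.
print(make_group_id("", "Encoder", 0))            # Encoder_block_0

# Recursing into children named "encoder" and "decoder" that share a block class:
print(make_group_id("encoder.", "ResBlock", 0))   # encoder.ResBlock_block_0
print(make_group_id("decoder.", "ResBlock", 0))   # decoder.ResBlock_block_0 (no collision)
```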
||
| pin_groups=pin_groups, | ||
| ) | ||
| _apply_group_offloading(module, config) | ||
|
|
||
|
|
@@ -617,7 +704,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf | |
| done at the top-level blocks and modules specified in block_modules. | ||
|
|
||
| When block_modules is provided, only those modules will be treated as blocks for offloading. For each specified | ||
| module, recursively apply block offloading to it. | ||
| module, we either offload the entire submodule or recursively apply block offloading to it. | ||
| """ | ||
| if config.stream is not None and config.num_blocks_per_group != 1: | ||
| logger.warning( | ||
|
|
@@ -634,7 +721,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf | |
|
|
||
| for name, submodule in module.named_children(): | ||
| # Check if this is an explicitly defined block module | ||
| if name in block_modules: | ||
| if block_modules and name in block_modules: | ||
| # Track submodule using a prefix to avoid filename collisions during disk offload. | ||
| # Without this, submodules sharing the same model class would be assigned identical | ||
| # filenames (derived from the class name). | ||
|
|
@@ -643,7 +730,6 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf | |
|
|
||
| _apply_group_offloading_block_level(submodule, submodule_config) | ||
| modules_with_group_offloading.add(name) | ||
|
|
||
| elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): | ||
| # Handle ModuleList and Sequential blocks as before | ||
| for i in range(0, len(submodule), config.num_blocks_per_group): | ||
|
|
@@ -672,6 +758,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf | |
| else: | ||
| # This is an unmatched module | ||
| unmatched_modules.append((name, submodule)) | ||
| modules_with_group_offloading.add(name) | ||
|
|
||
| # Apply group offloading hooks to the module groups | ||
| for i, group in enumerate(matched_module_groups): | ||
|
|
@@ -709,6 +796,25 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf | |
| _apply_group_offloading_hook(module, unmatched_group, config=config) | ||
| else: | ||
| _apply_lazy_group_offloading_hook(module, unmatched_group, config=config) | ||
| elif config.stream is None and config.offload_to_disk_path is None: | ||
|
Member
This seems unnecessary. Explain?
Author
I originally added the empty root hook to tag the top module as offloaded when everything else was matched, but it did not change behaviour: the child hooks already mark the model as group-offloaded and the guardrails rely on those. It just added an empty group and potential extra files, so I have removed it to simplify. Functionally nothing depends on it.
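For context, a rough sketch of the kind of hook-presence check being discussed, using the registry calls that appear in this diff (`HookRegistry.check_if_exists_or_initialize`, `get_hook`). The `_diffusers_hook` attribute and the `"group_offloading"` key are assumptions about where the registry stores itself, and whether the guardrails accept hooks on children or require one on the root is exactly the point under debate.

```python
import torch

# Illustrative only: a presence check that accepts a group-offloading hook on any
# submodule would not need the empty root group; one that inspects only the root would.
def has_group_offloading_hook(root: torch.nn.Module) -> bool:
    for submodule in root.modules():
        registry = getattr(submodule, "_diffusers_hook", None)  # assumed storage attribute
        if registry is not None and registry.get_hook("group_offloading") is not None:
            return True
    return False
```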
||
| # Ensure the top-level module always has a hook when no unmatched modules/params/buffers, | ||
| # to satisfy hook presence checks in tests. Using an empty group avoids extra offload files. | ||
| empty_group = ModuleGroup( | ||
| modules=[], | ||
| offload_device=config.offload_device, | ||
| onload_device=config.onload_device, | ||
| offload_to_disk_path=None, | ||
| offload_leader=module, | ||
| onload_leader=module, | ||
| parameters=[], | ||
| buffers=[], | ||
| non_blocking=False, | ||
| stream=None, | ||
| record_stream=False, | ||
| onload_self=True, | ||
| group_id=f"{config.module_prefix}{module.__class__.__name__}_empty_group", | ||
| ) | ||
| _apply_group_offloading_hook(module, empty_group, config=config) | ||
|
|
||
|
|
||
| def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: | ||
|
|
@@ -735,7 +841,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff | |
| record_stream=config.record_stream, | ||
| low_cpu_mem_usage=config.low_cpu_mem_usage, | ||
| onload_self=True, | ||
| group_id=name, | ||
| group_id=f"{config.module_prefix}{name}", | ||
|
Member
What's happening here?
Author
It is the same thing as above: we prefix the group id with `config.module_prefix` so that identically named groups from different submodules do not collide.
||
| ) | ||
| _apply_group_offloading_hook(submodule, group, config=config) | ||
| modules_with_group_offloading.add(name) | ||
|
|
@@ -782,10 +888,32 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff | |
| record_stream=config.record_stream, | ||
| low_cpu_mem_usage=config.low_cpu_mem_usage, | ||
| onload_self=True, | ||
| group_id=name, | ||
| group_id=f"{config.module_prefix}{name}", | ||
| ) | ||
| _apply_group_offloading_hook(parent_module, group, config=config) | ||
|
|
||
| # Ensure the top-level module also has a group_offloading hook so hook presence checks pass, | ||
| # even when it holds no parameters/buffers itself. | ||
| if config.stream is None: | ||
|
Collaborator
Why do we need this?
Author
Even when all real groups sit in child modules, the root needs a group-offloading hook of its own so that hook-presence checks on the top-level module pass.
||
| root_registry = HookRegistry.check_if_exists_or_initialize(module) | ||
| if root_registry.get_hook(_GROUP_OFFLOADING) is None: | ||
| empty_group = ModuleGroup( | ||
| modules=[], | ||
| offload_device=config.offload_device, | ||
| onload_device=config.onload_device, | ||
| offload_to_disk_path=None, | ||
| offload_leader=module, | ||
| onload_leader=module, | ||
| parameters=[], | ||
| buffers=[], | ||
| non_blocking=False, | ||
| stream=None, | ||
| record_stream=False, | ||
| onload_self=True, | ||
| group_id=f"{config.module_prefix}{module.__class__.__name__}_empty_group", | ||
| ) | ||
| root_registry.register_hook(GroupOffloadingHook(empty_group, config=config), _GROUP_OFFLOADING) | ||
|
|
||
| if config.stream is not None: | ||
| # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer | ||
| # and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the | ||
|
|
@@ -838,7 +966,7 @@ def _apply_lazy_group_offloading_hook( | |
| hook = GroupOffloadingHook(group, config=config) | ||
| registry.register_hook(hook, _GROUP_OFFLOADING) | ||
|
|
||
| lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook() | ||
| lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook(pin_groups=config.pin_groups) | ||
| registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING) | ||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -966,6 +966,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo | |||||
| # keys to ignore when AlignDeviceHook moves inputs/outputs between devices | ||||||
| # these are shared mutable state modified in-place | ||||||
| _skip_keys = ["feat_cache", "feat_idx"] | ||||||
| _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"] | ||||||
|
Member
It should also be made a part of
We have it for
Preferably initialized to a reasonable default like
Author
Have added
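For context, the attribute under discussion is a class-level list naming which children to treat as blocks, as the `AutoencoderKLWan` hunk above shows. Below is a minimal sketch of the same pattern on a placeholder model (all names in the sketch are placeholders, not diffusers classes).

```python
import torch

# Placeholder model mirroring the attribute the diff adds to AutoencoderKLWan:
# the listed children are treated as blocks for block-level group offloading,
# and _skip_keys marks kwargs that must keep their identity across forwards.
class MyAutoencoder(torch.nn.Module):
    _skip_keys = ["feat_cache", "feat_idx"]
    _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]

    def __init__(self):
        super().__init__()
        self.quant_conv = torch.nn.Conv2d(4, 4, 1)
        self.post_quant_conv = torch.nn.Conv2d(4, 4, 1)
        self.encoder = torch.nn.Sequential(torch.nn.Conv2d(3, 4, 3, padding=1))
        self.decoder = torch.nn.Sequential(torch.nn.Conv2d(4, 3, 3, padding=1))
```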
||||||
|
|
||||||
| @register_to_config | ||||||
| def __init__( | ||||||
|
|
||||||
Why do we have to set the default of `default_stream`?

I made it optional because the non-stream path calls `_process_tensors_from_modules` without a stream; there is nothing to record in that case, and `record_stream` is gated. `None` is a safety net for the record call, and it saves passing a placeholder from those call sites. If you prefer the stricter signature, I can keep it required and pass `None` explicitly where we don't use streams. Please correct my understanding if this needs to change.
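To close the thread, a hedged usage sketch of the new `pin_groups` argument, based on the signature and validation shown in this diff. The import path, the `use_stream=True` choice (the pinning decision is made by the lazy prefetch hook, which is registered on the stream path), and the model below are assumptions; the callable form uses the single-argument `pin_groups(submodule)` variant the hook tries first.

```python
import torch
from diffusers.hooks import apply_group_offloading  # assumed import path

# Placeholder model; any torch.nn.Module with block submodules works the same way.
model = torch.nn.Sequential(*[torch.nn.Linear(64, 64) for _ in range(8)])

# String form: keep the first and last parameter-bearing groups resident on the
# onload device and offload everything in between ("all" would pin every group).
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,  # pinning is applied by the LazyPrefetchGroupOffloadingHook on the first forward
    pin_groups="first_last",
)

# Callable form: pin any group whose module holds more than ~50M parameters.
def pin_large(submodule: torch.nn.Module) -> bool:
    return sum(p.numel() for p in submodule.parameters()) > 50_000_000

# apply_group_offloading(model, ..., pin_groups=pin_large)  # same call with a callable selector
```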