Z1/2 init: flatten params on device #7828
Open
ksugama wants to merge 21 commits into deepspeedai:master from ksugama:flatten-tensor-gpu (+155 −23)
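
Per the PR title and the dedicated test added below, the change makes ZeRO Stage 1/2 initialization flatten each parameter group directly on the accelerator when enough device memory is available, instead of flattening on CPU and then copying the flat buffer over. The following is a minimal illustrative sketch of that idea, not the actual DeepSpeed implementation; the helper name, the enough_vram flag, and the exact log wording are assumptions loosely modeled on the strings the new test asserts on.

```python
# Illustrative sketch only -- not DeepSpeed's real flatten path.
import torch
from torch._utils import _flatten_dense_tensors
from deepspeed.utils import logger


def flatten_param_group(params, device, enough_vram: bool):
    """Flatten a list of parameter tensors into one contiguous buffer."""
    if enough_vram:
        # GPU path: move the individual tensors first, then flatten on device.
        logger.info(f"Flattening param group of {len(params)} tensors on GPU")
        return _flatten_dense_tensors([p.data.to(device) for p in params])
    # Fallback: flatten on CPU, then copy the single flat buffer to the device.
    logger.info(f"Flattening param group of {len(params)} tensors on CPU, then moving")
    return _flatten_dense_tensors([p.data.cpu() for p in params]).to(device)
```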
Commits (21):
500bd7f  Improve engine's cleanup (#7813)  (tohtana)
792d52c  Ignore evoformer test (#7815)  (tohtana)
374abb8  Fix typos in accelerator setup guide (#7818)  (nathon-lee)
c272de4  Raise clear error on in-place GatheredParameters edits without modifi…  (tohtana)
d2aaf0f  [Bugfix] Resolve Rank index out of range during BWD when sp_size < wo…  (Flink-ddd)
0bc2dd2  Update PyTorch to v2.9 for modal tests (#7816)  (tohtana)
4cb023c  Update version.txt to 0.18.6 after latest release (#7826)  (loadams)
6064c2a  Fix leaf module race condition (#7825)  (tohtana)
e86db60  Skip sequence parallel operations during eval (#7821)  (jp1924)
129b42c  Support custom partitioning patterns for AutoTP (#7806)  (tohtana)
d307396  flatten gpu side  (ksugama)
6eb35e8  repro script  (ksugama)
48ecb1d  detect gpu count in repro  (ksugama)
7f98cc8  add .venv to path  (ksugama)
b3944b4  clean up  (ksugama)
78e58fc  format and delete repro script  (ksugama)
7aa7073  add dedicated test  (ksugama)
85d670a  parametrize tests  (ksugama)
60d5cb9  Fix gradient is ready with z2 (#7829)  (sfc-gh-truwase)
3610631  Fix AutoTP custom patterns: respect use_default_specs (#7827)  (tohtana)
c2bb55b  Merge branch 'master' into flatten-tensor-gpu  (ksugama)
New file added by this PR (+109 lines), the dedicated test for the ZeRO-1/2 GPU flatten path:

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
Test that ZeRO Stage 1 and 2 use the GPU flatten path when VRAM is sufficient.
Parametrized over zero_stage (1, 2) and dtype (fp32, fp16, bf16).
"""

import pytest
import deepspeed
from deepspeed.accelerator import get_accelerator
from deepspeed.utils import set_log_level_from_string
from unit.common import DistributedTest
from unit.simple_model import SimpleModel


def _apply_dtype_to_config(config_dict, dtype):
    """Set bf16/fp16 in config_dict based on dtype; skip if not supported."""
    if dtype == "bf16":
        if not get_accelerator().is_bf16_supported():
            pytest.skip("bf16 is not supported on this accelerator")
        config_dict["bf16"] = {"enabled": True}
    elif dtype == "fp16":
        if not get_accelerator().is_fp16_supported():
            pytest.skip("fp16 is not supported on this accelerator")
        config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
    # fp32: no half-precision block


@pytest.mark.parametrize("zero_stage", [1, 2])
@pytest.mark.parametrize("dtype", ["fp32", "fp16", "bf16"], ids=["fp32", "fp16", "bf16"])
class TestStage2FlattenOnGPU(DistributedTest):
    """ZeRO-1 and ZeRO-2 with small model should flatten on GPU (sufficient VRAM)."""

    world_size = 2  # Run on 2 GPUs when available

    def test_flatten_on_gpu_path_taken(self, monkeypatch, zero_stage, dtype):
        """Assert the GPU flatten path was used (not CPU flatten + move)."""
        if not get_accelerator().is_available():
            pytest.skip("Accelerator not available")
        config_dict = {
            "train_micro_batch_size_per_gpu": 2,
            "gradient_accumulation_steps": 1,
            "zero_optimization": {
                "stage": zero_stage
            },
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            },
        }
        _apply_dtype_to_config(config_dict, dtype)

        set_log_level_from_string("info")
        log_messages = []

        def mock_logger_info(msg, *args, **kwargs):
            log_messages.append(msg if isinstance(msg, str) else str(msg))

        monkeypatch.setattr("deepspeed.utils.logger.info", mock_logger_info)

        hidden_dim = 64
        model = SimpleModel(hidden_dim=hidden_dim, nlayers=2)
        deepspeed.initialize(
            config=config_dict,
            model=model,
            model_parameters=model.parameters(),
        )

        # Small model + no CPU offload => GPU path; that path logs "on GPU"
        gpu_path_logs = [m for m in log_messages if "Flattening param group" in m and "on GPU" in m]
        assert gpu_path_logs, (
            f"Expected GPU flatten path (logger.info should be called with 'Flattening param group' and 'on GPU'). "
            f"Captured messages: {log_messages}")

    def test_flat_buffers_on_accelerator(self, zero_stage, dtype):
        """Regression: flat buffers must end up on the accelerator (not left on CPU)."""
        if not get_accelerator().is_available():
            pytest.skip("Accelerator not available")
        config_dict = {
            "train_micro_batch_size_per_gpu": 2,
            "gradient_accumulation_steps": 1,
            "zero_optimization": {
                "stage": zero_stage
            },
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            },
        }
        _apply_dtype_to_config(config_dict, dtype)

        hidden_dim = 64
        model = SimpleModel(hidden_dim=hidden_dim, nlayers=2)
        engine, _, _, _ = deepspeed.initialize(
            config=config_dict,
            model=model,
            model_parameters=model.parameters(),
        )
        opt = engine.optimizer
        assert hasattr(opt, "bit16_groups_flat"), "ZeRO-1/2 optimizer should have bit16_groups_flat"
        device_type = get_accelerator().device_name()
        for i, flat in enumerate(opt.bit16_groups_flat):
            assert flat.device.type == device_type, (f"Flat buffer {i} must be on {device_type}, got {flat.device}")
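
For a quick manual check of the same property outside the pytest harness, a standalone script along these lines would print where the flat buffers end up. This is a hedged sketch: the model, config, and file name (check_flat_device.py) are illustrative, and deepspeed.initialize expects a distributed launch (for example via the deepspeed or torchrun launcher) even for a single process.

```python
# Illustrative manual check; run under a distributed launcher, e.g.:
#   deepspeed --num_gpus 1 check_flat_device.py
import torch
import deepspeed

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 10))
config = {
    "train_micro_batch_size_per_gpu": 2,
    "zero_optimization": {"stage": 2},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}
engine, _, _, _ = deepspeed.initialize(config=config,
                                       model=model,
                                       model_parameters=model.parameters())
for i, flat in enumerate(engine.optimizer.bit16_groups_flat):
    # With this PR, the flat buffers should live on the accelerator, not CPU.
    print(i, flat.shape, flat.device)
```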
Review comment:
@stas00 I believe available_memory() eventually calls into nvidia-smi. Is this the foot gun you were warning about? If it is, maybe this should be fixed in a different PR, since that problem touches more than these changes.
Reply:
Please remind me what issue I flagged that you are referring to? available_memory is here: DeepSpeed/accelerator/cuda_accelerator.py, lines 186 to 194 in a44fb58.
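
For context on the memory query being discussed: PyTorch can report free device memory without going through NVML or nvidia-smi via torch.cuda.mem_get_info. The snippet below is only an illustration of that alternative; it is not what DeepSpeed's available_memory currently does, and it is CUDA-specific rather than accelerator-agnostic.

```python
# Illustration only: free device memory via the CUDA runtime instead of NVML.
# Note: mem_get_info reports device-wide free bytes, which can differ from what
# the PyTorch caching allocator has already reserved but not yet used.
import torch


def free_device_bytes(device_index: int = 0) -> int:
    free_bytes, _total_bytes = torch.cuda.mem_get_info(device_index)
    return free_bytes
```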