This document outlines all concrete implementations that should be created for the distributed training framework, based on industry standards and real-world scenarios.
```text
ICommunicationBackend<T>
        ↓
CommunicationBackendBase<T> (abstract)
        ↓
├── InMemoryCommunicationBackend<T> (for testing)
├── MPICommunicationBackend<T> (MPI.NET for production)
├── NCCLCommunicationBackend<T> (NVIDIA GPUs)
└── GlooCommunicationBackend<T> (CPU-based)
```
```text
IShardedModel<T, TInput, TOutput>
        ↓
ShardedModelBase<T, TInput, TOutput> (abstract)
        ↓
├── FSDPModel<T, TInput, TOutput> (Fully Sharded Data Parallel - PyTorch style)
├── ZeRO1Model<T, TInput, TOutput> (ZeRO Stage 1 - optimizer state sharding only)
├── ZeRO2Model<T, TInput, TOutput> (ZeRO Stage 2 - optimizer + gradient sharding)
├── ZeRO3Model<T, TInput, TOutput> (ZeRO Stage 3 - full parameter sharding)
├── DDPModel<T, TInput, TOutput> (Distributed Data Parallel - parameter replication)
├── PipelineParallelModel<T, TInput, TOutput> (GPipe-style pipeline parallelism)
├── TensorParallelModel<T, TInput, TOutput> (Megatron-LM style tensor parallelism)
└── HybridShardedModel<T, TInput, TOutput> (3D parallelism: data + tensor + pipeline)
```
```text
IShardedOptimizer<T, TInput, TOutput>
        ↓
ShardedOptimizerBase<T, TInput, TOutput> (abstract)
        ↓
├── ZeRO1Optimizer<T, TInput, TOutput> (Shards optimizer state only)
├── ZeRO2Optimizer<T, TInput, TOutput> (Shards optimizer state + gradients)
├── ZeRO3Optimizer<T, TInput, TOutput> (Full sharding with parameter partitioning)
├── DDPOptimizer<T, TInput, TOutput> (Standard data parallel - AllReduce gradients)
├── GradientCompressionOptimizer<T, TInput, TOutput> (Compressed gradient communication)
├── AsyncSGDOptimizer<T, TInput, TOutput> (Asynchronous parameter updates)
└── ElasticOptimizer<T, TInput, TOutput> (Supports dynamic scaling of workers)
```
## Sharded Models

### FSDPModel<T, TInput, TOutput>
Status: ✅ Currently implemented as ShardedModel
Description: PyTorch FSDP-inspired implementation that shards model parameters, gradients, and optimizer states across all processes.
Key Features:
- Full parameter sharding across all ranks
- AllGather parameters before forward/backward pass
- AllReduce gradients after backward pass
- Minimal memory footprint per GPU
- Best for training very large models (billions of parameters)
Use Case: Training models that don't fit on a single GPU (e.g., LLMs with 7B+ parameters)
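A minimal sketch of the FSDP train step is shown below. It assumes the backend exposes an AllGather call in the same style as the AllReduce/ReduceScatter calls used elsewhere in this document; WrappedModel.SetParameters and the ExtractLocalShard helper are illustrative, not part of the current interfaces.

```csharp
// Hedged sketch of the FSDP step. AllGather, SetParameters, and
// ExtractLocalShard are assumptions for illustration only.
public override void Train(TInput input, TOutput expectedOutput)
{
    // 1. Materialize the full parameter vector from all shards.
    var fullParams = Config.CommunicationBackend.AllGather(LocalShard);
    WrappedModel.SetParameters(fullParams);

    // 2. Run the local forward/backward pass on this rank's batch.
    WrappedModel.Train(input, expectedOutput);

    // 3. Average gradients across ranks, then discard the full copy
    //    and keep only this rank's shard to minimize resident memory.
    SynchronizeGradients();
    LocalShard = ExtractLocalShard(fullParams); // hypothetical helper
}
```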
### ZeRO1Model<T, TInput, TOutput>
Status: ❌ To be implemented
Description: DeepSpeed ZeRO Stage 1 - only shards optimizer states, keeps parameters and gradients replicated.
Key Features:
- Parameters: Replicated across all ranks (like DDP)
- Gradients: Replicated across all ranks
- Optimizer states: Sharded across ranks (4-8x memory reduction for optimizer state)
- AllReduce for gradient synchronization
- Lower communication overhead than full sharding
Use Case: Medium-sized models where optimizer state is the memory bottleneck (e.g., Adam with 2x model size overhead)
Implementation Notes:
```csharp
public class ZeRO1Model<T, TInput, TOutput> : ShardedModelBase<T, TInput, TOutput>
{
    // Keep a full copy of the parameters locally (ZeRO-1 does not shard them).
    private Vector<T> _fullParameters;

    protected override void InitializeSharding()
    {
        // Don't shard parameters; every rank keeps the full copy.
        _fullParameters = WrappedModel.GetParameters();
        LocalShard = _fullParameters; // No actual sharding
    }

    public override void SynchronizeGradients()
    {
        // Standard AllReduce for gradient averaging; optimizer state
        // sharding is handled by ZeRO1Optimizer.
        var gradients = GetGradients();
        Config.CommunicationBackend.AllReduce(gradients, ReductionOperation.Average);
        SetGradients(gradients);
    }
}
```

### ZeRO2Model<T, TInput, TOutput>
Status: ❌ To be implemented
Description: DeepSpeed ZeRO Stage 2 - shards optimizer states AND gradients, keeps parameters replicated.
Key Features:
- Parameters: Replicated across all ranks
- Gradients: Sharded across ranks (additional memory savings)
- Optimizer states: Sharded across ranks
- ReduceScatter for gradient sharding
- AllGather for parameter updates
- 4-8x memory reduction vs DDP
Use Case: Large models where gradient + optimizer memory is significant (e.g., models with 1B-10B parameters)
Implementation Notes:
```csharp
public class ZeRO2Model<T, TInput, TOutput> : ShardedModelBase<T, TInput, TOutput>
{
    // This rank's shard of the averaged gradients.
    private Vector<T> _gradientShard;

    public override void SynchronizeGradients()
    {
        // ReduceScatter both averages gradients across ranks and leaves
        // each rank holding only its own shard of the result.
        var fullGradients = GetGradients();
        _gradientShard = Config.CommunicationBackend.ReduceScatter(
            fullGradients,
            ReductionOperation.Average);
    }
}
```

### ZeRO3Model<T, TInput, TOutput>
Status: ❌ To be implemented (similar to the current FSDP implementation)
Description: DeepSpeed ZeRO Stage 3 - full sharding of parameters, gradients, and optimizer states.
Key Features:
- Parameters: Sharded across ranks, AllGather on-demand
- Gradients: Sharded across ranks
- Optimizer states: Sharded across ranks
- Maximum memory efficiency (up to 64x reduction)
- Higher communication overhead
Use Case: Extremely large models (10B-175B+ parameters) that require multi-GPU/multi-node training
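Since ZeRO-3 materializes parameters on demand, the key override is GatherFullParameters. A minimal sketch, assuming the backend's AllGather concatenates each rank's shard in rank order:

```csharp
public override Vector<T> GatherFullParameters()
{
    // Every rank contributes its shard; AllGather reconstructs the
    // full parameter vector on demand (assumed signature).
    return Config.CommunicationBackend.AllGather(LocalShard);
}
```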
### DDPModel<T, TInput, TOutput>
Status: ❌ To be implemented
Description: Traditional DDP like PyTorch DDP - parameters replicated, gradients synchronized.
Key Features:
- Parameters: Fully replicated on each rank
- Gradients: Synchronized via AllReduce after backward pass
- Optimizer states: Fully replicated on each rank
- Lowest communication overhead
- Simple and robust
- Best for models that fit comfortably on a single GPU
Use Case: Training medium-sized models (< 1B parameters) across multiple GPUs for faster training
Implementation Notes:
```csharp
public class DDPModel<T, TInput, TOutput> : ShardedModelBase<T, TInput, TOutput>
{
    protected override void InitializeSharding()
    {
        // No sharding - each rank has full parameters.
        var fullParams = WrappedModel.GetParameters();
        LocalShard = fullParams;
        CachedFullParameters = fullParams;
    }

    public override Vector<T> GatherFullParameters()
    {
        // Already have full parameters, no gather needed.
        return LocalShard;
    }

    public override void SynchronizeGradients()
    {
        // AllReduce gradients to average across all ranks.
        var gradients = GetGradients();
        Config.CommunicationBackend.AllReduce(gradients, ReductionOperation.Average);
        SetGradients(gradients);
    }
}
```

### PipelineParallelModel<T, TInput, TOutput>
Status: ❌ To be implemented
Description: GPipe-style pipeline parallelism - splits model into stages across ranks.
Key Features:
- Model layers divided into pipeline stages
- Each rank owns different layers
- Forward pass flows through pipeline
- Backward pass flows in reverse
- Micro-batching to keep all ranks busy
- Reduces memory per GPU by splitting model vertically
Use Case: Very deep models (transformers with 100+ layers) or when model architecture is easily divisible
Implementation Notes:
```csharp
public class PipelineParallelModel<T, TInput, TOutput> : ShardedModelBase<T, TInput, TOutput>
{
    // Index of the pipeline stage owned by this rank.
    private int _pipelineStage;

    // The contiguous block of layers assigned to this stage.
    private IFullModel<T, TInput, TOutput>[] _stageModels;

    public override void Train(TInput input, TOutput expectedOutput)
    {
        // Forward pass: run the local stage, send activations to the next stage.
        // Backward pass: receive gradients from the next stage, send them to the previous one.
        // Micro-batching keeps all stages busy and overlaps computation with communication.
    }
}
```

### TensorParallelModel<T, TInput, TOutput>
Status: ❌ To be implemented
Description: Megatron-LM style tensor parallelism - splits individual layers across ranks.
Key Features:
- Each layer's tensors split across ranks
- Column-wise or row-wise partitioning
- AllReduce within each layer
- Reduces memory per GPU by splitting model horizontally
- High communication overhead
Use Case: Very wide models (large transformers with huge hidden dimensions) or when activation memory is the bottleneck
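As a sketch of the core idea, a column-parallel linear layer stores only a slice of the weight matrix per rank. The _weightColumns field and its Multiply method are hypothetical, as is the AllGather signature:

```csharp
// Hypothetical column-parallel linear layer (Megatron-LM style).
// Each rank stores a column slice of the weight matrix.
public Vector<T> Forward(Vector<T> input)
{
    // The local matmul yields this rank's slice of the output features.
    var localOutput = _weightColumns.Multiply(input);

    // Concatenate slices from all ranks to form the full output.
    // (A row-parallel layer would instead AllReduce partial sums.)
    return Config.CommunicationBackend.AllGather(localOutput);
}
```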
### HybridShardedModel<T, TInput, TOutput>
Status: ❌ To be implemented
Description: Combines data parallelism, tensor parallelism, and pipeline parallelism.
Key Features:
- Data parallelism across data parallel ranks
- Tensor parallelism within each data parallel group
- Pipeline parallelism for model depth
- Maximum scalability for trillion-parameter models
- Complex but most memory efficient for extreme scale
Use Case: Training models with 100B-1T+ parameters across hundreds/thousands of GPUs
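One way to see the 3D layout is as a factorization of the flat rank into three coordinates. The sketch below is purely illustrative and assumes the world size equals the product of the three group sizes:

```csharp
// Sketch: map a flat rank onto (data, tensor, pipeline) coordinates.
// E.g., with tensorSize = 4 and pipelineSize = 4, 64 ranks give a
// data-parallel degree of 4. Group sizes here are illustrative.
public static (int data, int tensor, int pipeline) DecomposeRank(
    int rank, int tensorSize, int pipelineSize)
{
    int tensor = rank % tensorSize;
    int pipeline = (rank / tensorSize) % pipelineSize;
    int data = rank / (tensorSize * pipelineSize);
    return (data, tensor, pipeline);
}
```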
## Sharded Optimizers

### ZeRO1Optimizer<T, TInput, TOutput>
Status: ❌ To be implemented
Description: Shards optimizer states (momentum, variance buffers) across ranks.
Key Features:
- Each rank stores 1/N of optimizer states
- AllGather optimizer states when needed for updates
- 4-8x memory reduction for optimizer (especially Adam)
- Works with DDPModel or ZeRO1Model
Implementation Notes:
```csharp
public class ZeRO1Optimizer<T, TInput, TOutput> : ShardedOptimizerBase<T, TInput, TOutput>
{
    // Each rank stores only its shard of the optimizer state
    // (e.g., Adam momentum/variance buffers), keyed by state name.
    private Dictionary<string, Vector<T>> _shardedOptimizerStates;

    protected override void UpdateOptimizerState(Vector<T> gradients)
    {
        // Only update this rank's shard of the optimizer state.
        // AllGather when needed for the full parameter update.
    }
}
```

### ZeRO2Optimizer<T, TInput, TOutput>
Status: ❌ To be implemented
Description: Shards both gradients and optimizer states.
Key Features:
- ReduceScatter gradients to shard them
- Each rank computes optimizer update for its shard
- AllGather updated parameters
- Works with ZeRO2Model
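A hedged sketch of that update path, assuming ReduceScatter/AllGather follow the calling style used elsewhere in this document; model.SetParameters and the ApplyLocalUpdate helper are illustrative:

```csharp
protected override void SynchronizeParameters(IFullModel<T, TInput, TOutput> model)
{
    // 1. Average and shard the gradients in a single collective.
    var gradShard = Config.CommunicationBackend.ReduceScatter(
        model.GetGradients(), ReductionOperation.Average);

    // 2. Apply the optimizer update to this rank's parameter shard only.
    var paramShard = ApplyLocalUpdate(gradShard); // hypothetical helper

    // 3. Reassemble the full, updated parameters on every rank.
    model.SetParameters(Config.CommunicationBackend.AllGather(paramShard));
}
```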
### ZeRO3Optimizer<T, TInput, TOutput>
Status: ✅ Currently implemented as ShardedOptimizer
Description: Full parameter, gradient, and optimizer state sharding.
### DDPOptimizer<T, TInput, TOutput>
Status: ❌ To be implemented
Description: Standard AllReduce-based gradient synchronization.
Key Features:
- AllReduce gradients after backward pass
- Each rank does identical optimizer update
- Simple and robust
- Works with DDPModel
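A minimal sketch of the synchronization step, mirroring the AllReduce pattern from the DDPModel snippet above:

```csharp
protected override void SynchronizeParameters(IFullModel<T, TInput, TOutput> model)
{
    // Average gradients across all ranks; every rank then performs an
    // identical local optimizer update, so parameters stay in sync
    // without any further communication.
    var gradients = model.GetGradients();
    Config.CommunicationBackend.AllReduce(gradients, ReductionOperation.Average);
    model.SetGradients(gradients);
}
```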
### GradientCompressionOptimizer<T, TInput, TOutput>
Status: ❌ To be implemented
Description: Compresses gradients before communication.
Key Features:
- Gradient compression (quantization, sparsification, low-rank)
- Reduced communication bandwidth
- Trade-off between accuracy and speed
- Works with any distributed model
Implementation Notes:
```csharp
public class GradientCompressionOptimizer<T, TInput, TOutput> : ShardedOptimizerBase<T, TInput, TOutput>
{
    private IGradientCompressor<T> _compressor;

    protected override void SynchronizeParameters(IFullModel<T, TInput, TOutput> model)
    {
        // Compress before communicating to cut bandwidth. Note: this
        // assumes the compressed representation can be summed element-wise.
        var gradients = model.GetGradients();
        var compressed = _compressor.Compress(gradients);
        Config.CommunicationBackend.AllReduce(compressed, ReductionOperation.Sum);
        var decompressed = _compressor.Decompress(compressed);
        model.SetGradients(decompressed);
    }
}
```

### AsyncSGDOptimizer<T, TInput, TOutput>
Status: ❌ To be implemented
Description: Asynchronous parameter updates without strict synchronization.
Key Features:
- No barriers - ranks update asynchronously
- Parameter server or peer-to-peer architecture
- Faster iteration time, but may affect convergence
- Works for large-scale training with many workers
### ElasticOptimizer<T, TInput, TOutput>
Status: ❌ To be implemented
Description: Supports dynamic addition/removal of workers during training.
Key Features:
- Handles rank changes gracefully
- Re-shards parameters when workers join/leave
- Fault tolerance for long-running jobs
- Works with elastic training frameworks
## Communication Backends

### InMemoryCommunicationBackend<T>
Status: ✅ Implemented
Use Case: Testing and development without MPI
### MPICommunicationBackend<T>
Status: ❌ To be implemented
Description: Production MPI.NET backend for CPU/GPU clusters.
Key Features:
- MPI_AllReduce, MPI_AllGather, etc.
- Works across machines (multi-node)
- Supports InfiniBand, RoCE networks
- Industry standard for HPC
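For reference, the MPI.NET surface this backend would wrap looks roughly like the following; the calls shown match the MPI.NET tutorial but should be verified against the exact MPI.NET version adopted:

```csharp
using MPI;

class AllReduceDemo
{
    static void Main(string[] args)
    {
        using (new MPI.Environment(ref args))
        {
            Intracommunicator comm = Communicator.world;

            // Each rank contributes a value; Allreduce returns the sum
            // on every rank. Divide by Size to get an average.
            double local = 0.5 * comm.Rank;
            double total = comm.Allreduce(local, Operation<double>.Add);

            if (comm.Rank == 0)
                System.Console.WriteLine($"average = {total / comm.Size}");
        }
    }
}
```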
### NCCLCommunicationBackend<T>
Status: ❌ To be implemented
Description: NVIDIA NCCL backend for GPU-to-GPU communication.
Key Features:
- Optimized for NVIDIA GPUs
- NVLink support for intra-node
- InfiniBand/RoCE for inter-node
- Fastest for NVIDIA hardware
### GlooCommunicationBackend<T>
Status: ❌ To be implemented
Description: Facebook Gloo backend for CPU clusters.
Key Features:
- CPU-based collective operations
- TCP/IP networking
- Good for heterogeneous environments
- No MPI dependency
## Implementation Priority

- ✅ InMemoryCommunicationBackend (done)
- ❌ DDPModel - Standard data parallel
- ❌ DDPOptimizer - AllReduce gradients
- ❌ MPICommunicationBackend - Production backend
- ❌ ZeRO1Model + ZeRO1Optimizer - Optimizer state sharding
- ❌ ZeRO2Model + ZeRO2Optimizer - Gradient + state sharding
- ✅ ZeRO3 (rename current ShardedModel/Optimizer to FSDPModel/FSDPOptimizer)
- ❌ PipelineParallelModel - Layer-wise parallelism
- ❌ TensorParallelModel - Tensor-wise parallelism
- ❌ HybridShardedModel - 3D parallelism
- ❌ GradientCompressionOptimizer - Reduce communication
- ❌ NCCLCommunicationBackend - GPU optimization
- ❌ AsyncSGDOptimizer - Async updates
- ❌ ElasticOptimizer - Dynamic scaling
## Adding a New Sharded Model

1. Inherit from ShardedModelBase<T, TInput, TOutput>
2. Override the required methods:
   - InitializeSharding() - How to shard/replicate parameters
   - Train() - Forward/backward with appropriate sync
   - GatherFullParameters() - How to reconstruct full parameters
   - SynchronizeGradients() - Gradient communication pattern
   - Serialize()/Deserialize() - Save/load with strategy metadata
3. Follow the naming convention: [Strategy]Model<T, TInput, TOutput>
4. Add comprehensive documentation with use cases and memory/communication trade-offs
5. Include example usage in XML docs

## Adding a New Sharded Optimizer

1. Inherit from ShardedOptimizerBase<T, TInput, TOutput>
2. Override the required methods:
   - Optimize() - Coordinate distributed optimization
   - SynchronizeOptimizerState() - Sync momentum/variance buffers
   - SynchronizeParameters() - Gradient/parameter communication
   - ShouldEarlyStop() - Consensus across ranks
3. Follow the naming convention: [Strategy]Optimizer<T, TInput, TOutput>
4. Match each optimizer with its corresponding model (e.g., DDPOptimizer works with DDPModel)
## Testing Requirements

For each implementation:
- Unit tests with InMemoryCommunicationBackend (2-4 ranks)
- Integration tests with small models
- Performance benchmarks comparing strategies
- Memory usage profiling
- Communication overhead measurements
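As an illustration of the first bullet, an xUnit-style test could simulate several ranks on one machine with the in-memory backend. The (rank, worldSize) constructor and the Vector<double> usage below are assumptions about the current API:

```csharp
using System.Linq;
using System.Threading.Tasks;
using Xunit;

public class AllReduceTests
{
    [Fact]
    public void AllReduce_Average_IsConsistentAcrossRanks()
    {
        const int worldSize = 4;
        var tasks = Enumerable.Range(0, worldSize).Select(rank => Task.Run(() =>
        {
            // Hypothetical constructor: one backend instance per simulated rank.
            var backend = new InMemoryCommunicationBackend<double>(rank, worldSize);
            var values = new Vector<double>(new[] { (double)rank });

            backend.AllReduce(values, ReductionOperation.Average);

            // Average of ranks 0..3 is 1.5 on every rank.
            Assert.Equal(1.5, values[0], precision: 6);
        }));
        Task.WaitAll(tasks.ToArray());
    }
}
```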
## Documentation Requirements

For each implementation:
- Class documentation following project standards
- Usage examples in code examples
- Performance characteristics (memory, communication, computation)
- When to use decision guide
- Limitations and caveats
## References

- PyTorch FSDP: https://pytorch.org/docs/stable/fsdp.html
- DeepSpeed ZeRO: https://www.deepspeed.ai/tutorials/zero/
- PyTorch DDP: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
- GPipe: https://arxiv.org/abs/1811.06965
- Megatron-LM: https://github.com/NVIDIA/Megatron-LM
- 3D Parallelism: https://arxiv.org/abs/2104.04473