diff --git a/README.md b/README.md index f8e6763b4..71595e11b 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ The MLCommons™ **AlgoPerf: Training Algorithms benchmark** is designed to find When training neural nets, practitioners face many critical yet often opaque decisions: What optimizer to choose? How should its learning rate be tuned? What learning rate schedule should be used? These choices can make or break training, yet the community has lacked a clear, standardized way to identify the state of the art. Unlike benchmarks focused on hardware or model architecture, AlgoPerf isolates the **training algorithm** itself, which includes the optimizer, regularization, data selection, and hyperparameters like the learning rate schedule. By standardizing the benchmark process, AlgoPerf offers a meaningful apples-to-apples comparison of training algorithms and follows the following **key principles**: -- 🎯 **Fixed Target, Model & Hardware:** Submitted training algorithms must train a set of [**fixed models**](/docs/DOCUMENTATION.md#workloads) to a pre-defined validation performance target as fast as possible. All submissions use the same model architecture and are run on the same [**standardized hardware**](/docs/DOCUMENTATION.md#benchmarking-hardware) (8x NVIDIA V100 GPUs). This isolates the training algorithm's performance and allows a fair apples-to-apples comparison. +- 🎯 **Fixed Target, Model & Hardware:** Submitted training algorithms must train a set of [**fixed models**](/docs/DOCUMENTATION.md#workloads) to a pre-defined validation performance target as fast as possible. All submissions use the same model architecture and are run on the same [**standardized hardware**](/docs/DOCUMENTATION.md#benchmarking-hardware) (4x NVIDIA A100 (40 GB) GPUs). This isolates the training algorithm's performance and allows a fair apples-to-apples comparison. - ⏱️ **Time-To-Result:** Submissions are evaluated based on the total wall-clock time required to reach the target, rewarding practical and efficient algorithms. - 🧠 **Diverse Workloads:** The benchmark includes [**8 diverse deep learning workloads**](/docs/DOCUMENTATION.md#workloads) across domains like image classification, speech recognition, and machine translation. A submission's score is computed by aggregating its performance, using [**performance profiles**](/docs/DOCUMENTATION.md#benchmark-score-using-performance-profiles), across all workloads to ensure general-purpose algorithms. - 📦 **Fully-Specified Algorithms:** Submissions must be complete procedures and thus hyperparameter tuning is treated as part of the algorithm. Submissions can either provide a search space for automated tuning ([**External tuning ruleset**](/docs/DOCUMENTATION.md#external-tuning-ruleset)) or be hyperparameter-free ([**Self-tuning ruleset**](/docs/DOCUMENTATION.md#self-tuning-ruleset)) with any tuning done automatically and "on the clock". This measures an algorithm's _total_ practical cost and provides practitioners with a complete method, eliminating the guesswork of how to apply it. 
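The performance-profile scoring referenced in the README hunk above can be made concrete in a few lines. A minimal sketch in the spirit of Dolan–Moré performance profiles (illustrative names only, not the repo's actual implementation in `scoring/performance_profile.py`):

```python
import numpy as np

def performance_profiles(times: dict, taus: np.ndarray) -> dict:
  """times maps submission name -> per-workload runtimes (np.inf if the
  target was never reached). Returns, for each submission, the fraction of
  workloads it solves within tau times the best runtime, for each tau."""
  runtimes = np.stack(list(times.values()))  # (num_submissions, num_workloads)
  best = runtimes.min(axis=0)  # fastest submission per workload
  ratios = runtimes / best  # performance ratios r_{s,w} >= 1
  return {
    s: np.array([(ratios[i] <= tau).mean() for tau in taus])
    for i, s in enumerate(times)
  }
```

A submission's benchmark score is then, up to normalization, the area under its profile curve over the chosen tau range, so curves that sit higher (faster on more workloads) score better.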
diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index af09e67fc..937001b87 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -20,6 +20,7 @@ def pytorch_setup() -> Tuple[bool, int, torch.device, int]: + torch.set_float32_matmul_precision('high') use_pytorch_ddp = 'LOCAL_RANK' in os.environ rank = int(os.environ['LOCAL_RANK']) if use_pytorch_ddp else 0 device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu') diff --git a/algoperf/workloads/cifar/cifar_pytorch/workload.py b/algoperf/workloads/cifar/cifar_pytorch/workload.py index a6e8569cc..f053fd828 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/workload.py +++ b/algoperf/workloads/cifar/cifar_pytorch/workload.py @@ -110,12 +110,12 @@ def _build_dataset( batch_size=ds_iter_batch_size, shuffle=not USE_PYTORCH_DDP and is_train, sampler=sampler, - num_workers=4 if is_train else self.eval_num_workers, + num_workers=2 * N_GPUS if is_train else self.eval_num_workers, pin_memory=True, drop_last=is_train, ) - dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE) dataloader = data_utils.cycle(dataloader, custom_sampler=USE_PYTORCH_DDP) + dataloader = data_utils.dataloader_iterator_wrapper(dataloader, DEVICE) return dataloader def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: diff --git a/algoperf/workloads/criteo1tb/workload.py b/algoperf/workloads/criteo1tb/workload.py index 2cb7e5450..4d2196cd5 100644 --- a/algoperf/workloads/criteo1tb/workload.py +++ b/algoperf/workloads/criteo1tb/workload.py @@ -95,11 +95,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 7_703 # ~2.1 hours. + return 8_915 # ~2.4 hours. @property def eval_period_time_sec(self) -> int: - return 2 * 60 # 2 mins. 
+ return 356 # approx 25 evals def _build_input_queue( self, diff --git a/algoperf/workloads/fastmri/workload.py b/algoperf/workloads/fastmri/workload.py index 0b1ecfaa1..b87dfc755 100644 --- a/algoperf/workloads/fastmri/workload.py +++ b/algoperf/workloads/fastmri/workload.py @@ -95,11 +95,11 @@ def accelerations(self): @property def max_allowed_runtime_sec(self) -> int: - return 4_430 # ~1.2 hours + return 2_745 # ~0.7 hours @property def eval_period_time_sec(self) -> int: - return 80 + return 110 # approx 25 evals @property def step_hint(self) -> int: diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index d5366c60d..289136bfb 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -3,10 +3,13 @@ import contextlib import functools import itertools +import json import math import os import random -from typing import Dict, Iterator, Optional, Tuple +import time +from pathlib import Path +from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union import numpy as np import torch @@ -14,7 +17,11 @@ import torch.nn.functional as F from torch.nn.parallel import DistributedDataParallel as DDP from torchvision import transforms -from torchvision.datasets.folder import ImageFolder +from torchvision.datasets.folder import ( + IMG_EXTENSIONS, + ImageFolder, + default_loader, +) import algoperf.random_utils as prng from algoperf import data_utils, param_utils, pytorch_utils, spec @@ -28,6 +35,100 @@ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() +class CachedImageFolder(ImageFolder): + """ImageFolder that caches the file listing to avoid repeated filesystem scans.""" + + def __init__( + self, + root: Union[str, Path], + cache_file: Optional[Union[str, Path]] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + loader: Callable[[str], Any] = default_loader, + is_valid_file: Optional[Callable[[str], bool]] = None, + allow_empty: bool = False, + rebuild_cache: bool = False, + cache_build_timeout_minutes: int = 30, + ): + self.root = os.path.abspath(root) + self.transform = transform + self.target_transform = target_transform + self.loader = loader + self.extensions = IMG_EXTENSIONS if is_valid_file is None else None + + # Default cache location: .cache_index.json in the root directory + if cache_file is None: + cache_file = os.path.join(self.root, '.cache_index.json') + self.cache_file = cache_file + + is_distributed = dist.is_available() and dist.is_initialized() + rank = dist.get_rank() if is_distributed else 0 + + cache_exists = os.path.exists(self.cache_file) + needs_rebuild = rebuild_cache or not cache_exists + + if needs_rebuild: + # We only want one process to build the cache + # and others to wait for it to finish. 
+ if rank == 0: + self._build_and_save_cache(is_valid_file, allow_empty) + if is_distributed: + self._wait_for_cache(timeout_minutes=cache_build_timeout_minutes) + dist.barrier() + + self._load_from_cache() + + self.targets = [s[1] for s in self.samples] + self.imgs = self.samples + + def _wait_for_cache(self, timeout_minutes: int): + """Poll for cache file to exist.""" + timeout_seconds = timeout_minutes * 60 + poll_interval = 5 + elapsed = 0 + + while not os.path.exists(self.cache_file): + if elapsed >= timeout_seconds: + raise TimeoutError( + f'Timed out waiting for cache file after {timeout_minutes} minutes: {self.cache_file}' + ) + time.sleep(poll_interval) + elapsed += poll_interval + + def _load_from_cache(self): + """Load classes and samples from cache file.""" + with open(os.path.abspath(self.cache_file), 'r') as f: + cache = json.load(f) + self.classes = cache['classes'] + self.class_to_idx = cache['class_to_idx'] + # Convert relative paths back to absolute + self.samples = [ + (os.path.join(self.root, rel_path), idx) + for rel_path, idx in cache['samples'] + ] + + def _build_and_save_cache(self, is_valid_file, allow_empty): + """Scan filesystem, build index, and save to cache.""" + self.classes, self.class_to_idx = self.find_classes(self.root) + self.samples = self.make_dataset( + self.root, + class_to_idx=self.class_to_idx, + extensions=self.extensions, + is_valid_file=is_valid_file, + allow_empty=allow_empty, + ) + + cache = { + 'classes': self.classes, + 'class_to_idx': self.class_to_idx, + 'samples': [ + (os.path.relpath(path, self.root), idx) for path, idx in self.samples + ], + } + with open(os.path.abspath(self.cache_file), 'w') as f: + json.dump(cache, f) + + def imagenet_v2_to_torch( batch: Dict[str, spec.Tensor], ) -> Dict[str, spec.Tensor]: @@ -119,8 +220,10 @@ def _build_dataset( ) folder = 'train' if 'train' in split else 'val' - dataset = ImageFolder( - os.path.join(data_dir, folder), transform=transform_config + dataset = CachedImageFolder( + os.path.join(data_dir, folder), + transform=transform_config, + cache_file='.imagenet_{}_cache_index.json'.format(split), ) if split == 'eval_train': @@ -145,16 +248,16 @@ def _build_dataset( sampler = data_utils.DistributedEvalSampler( dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False ) - dataloader = torch.utils.data.DataLoader( dataset, batch_size=ds_iter_batch_size, shuffle=not USE_PYTORCH_DDP and is_train, sampler=sampler, - num_workers=4 if is_train else self.eval_num_workers, + num_workers=5 * N_GPUS if is_train else self.eval_num_workers, pin_memory=True, drop_last=is_train, persistent_workers=is_train, + prefetch_factor=N_GPUS, ) dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE) dataloader = data_utils.cycle( @@ -163,7 +266,6 @@ def _build_dataset( use_mixup=use_mixup, mixup_alpha=0.2, ) - return dataloader def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: diff --git a/algoperf/workloads/imagenet_resnet/workload.py b/algoperf/workloads/imagenet_resnet/workload.py index ef696e328..de8458c92 100644 --- a/algoperf/workloads/imagenet_resnet/workload.py +++ b/algoperf/workloads/imagenet_resnet/workload.py @@ -103,11 +103,11 @@ def resize_size(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 66_159 # ~18.4 hours + return 49_918 # ~13.8 hours @property def eval_period_time_sec(self) -> int: - return 510 # 8.5 minutes. 
+ return 1_996 # approx 25 evals def _build_dataset( self, diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py index fc2a3cd46..06df7ea75 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -5,7 +5,6 @@ and https://github.com/lucidrains/vit-pytorch. """ -import math from typing import Any, Optional, Tuple, Union import torch @@ -126,13 +125,14 @@ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: value_layer = self.transpose_for_scores(self.value(x)) query_layer = self.transpose_for_scores(mixed_query_layer) - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - attention_scores = attention_scores / math.sqrt(self.head_dim) - - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = F.dropout(attention_probs, dropout_rate, self.training) + # Use built-in scaled_dot_product_attention (Flash Attention when available) + context_layer = F.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + dropout_p=dropout_rate if self.training else 0.0, + ) - context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_dim,) context_layer = context_layer.view(new_context_layer_shape) diff --git a/algoperf/workloads/imagenet_vit/workload.py b/algoperf/workloads/imagenet_vit/workload.py index 2a0070ba4..4da02614f 100644 --- a/algoperf/workloads/imagenet_vit/workload.py +++ b/algoperf/workloads/imagenet_vit/workload.py @@ -88,11 +88,11 @@ def eval_batch_size(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 69_768 # ~19.4 hours + return 64_292 # ~17.8 hours @property def eval_period_time_sec(self) -> int: - return 7 * 60 # 7 mins. + return 2_571 # approx 25 evals
def _build_dataset( self, diff --git a/algoperf/workloads/librispeech_conformer/workload.py b/algoperf/workloads/librispeech_conformer/workload.py index 791270719..5a0a546e4 100644 --- a/algoperf/workloads/librispeech_conformer/workload.py +++ b/algoperf/workloads/librispeech_conformer/workload.py @@ -80,11 +80,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 58_015 # ~16.1 hours + return 43_680 # ~12.1 hours @property def eval_period_time_sec(self) -> int: - return 24 * 60 + return 1747 # approx 25 evals @property def step_hint(self) -> int: diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 3a320b0dd..2a8fd29d0 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -100,7 +100,11 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 44_405 # ~12.3 hours + return 36_949 # ~10.3 hours + + @property + def eval_period_time_sec(self) -> int: + return 1447 # approx 25 evals @property def use_tanh(self) -> bool: diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 672f3440f..c6bb149f7 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -96,7 +96,11 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 44_405 # ~12.3 hours + return 36_949 # ~10.3 hours + + @property + def eval_period_time_sec(self) -> int: + return 1447 # approx 25 evals @property def use_tanh(self) -> bool: diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 8717e46d6..771b103a0 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -88,11 +88,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 12_011 # ~3.3 hours + return 11_303 # ~3.1 hours @property def eval_period_time_sec(self) -> int: - return 4 * 60 + return 452 # approx 25 evals def _build_input_queue( self, diff --git a/algoperf/workloads/wmt/workload.py b/algoperf/workloads/wmt/workload.py index 40e4262dd..2e232214e 100644 --- a/algoperf/workloads/wmt/workload.py +++ b/algoperf/workloads/wmt/workload.py @@ -89,11 +89,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 43_336 # ~12.0 hours + return 16_114 # ~4.5 hours @property def eval_period_time_sec(self) -> int: - return 14 * 60 + return 644 # approx 25 evals @property def step_hint(self) -> int: diff --git a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py index 0577cd4e0..cf431de24 100644 --- a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py @@ -340,12 +340,6 @@ def update_params( dropout_rate, ) ) - - # Log loss, grad_norm. 
- if global_step % 100 == 0 and workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - {'loss': loss.item(), 'grad_norm': grad_norm.item()}, global_step - ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state diff --git a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py index 0b32199ba..494ada4c8 100644 --- a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -5,7 +5,6 @@ import torch import torch.distributed.nn as dist_nn -from absl import logging from torch import Tensor from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR @@ -300,28 +299,6 @@ def update_params( optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() - # Log training metrics - loss, grad_norm, batch_size. - if global_step <= 100 or global_step % 500 == 0: - with torch.no_grad(): - parameters = [p for p in current_model.parameters() if p.grad is not None] - grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 - ) - if workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, - global_step, - ) - logging.info( - '%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item(), - ) - return (optimizer_state, current_param_container, new_model_state) diff --git a/docker/build_docker_images.sh b/docker/build_docker_images.sh index 6b5e67ceb..aa94222ea 100644 --- a/docker/build_docker_images.sh +++ b/docker/build_docker_images.sh @@ -27,7 +27,7 @@ then GIT_BRANCH='main' # Set default argument fi -FRAMEWORKS=( "jax" "pythorch" "both" ) +FRAMEWORKS=( "jax" "pytorch") if [[ -n "$FRAMEWORK" ]]; then diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 35ac30461..1cd676d2a 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -174,7 +174,7 @@ fi # Check if arguments are valid VALID_DATASETS=("criteo1tb" "imagenet" "fastmri" "ogbg" "librispeech" \ - "wmt" "mnist") + "wmt" "mnist" "fineweb_edu_10B") VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_resnet_gelu" \ "imagenet_resnet_large_bn_init" "imagenet_vit" "imagenet_vit_glu" \ "imagenet_vit_post_ln" "imagenet_vit_map" "fastmri" "ogbg" \ @@ -185,7 +185,7 @@ VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_ "librispeech_conformer_gelu" "fastmri_model_size" "fastmri_tanh" \ "librispeech_deepspeech_tanh" \ "librispeech_deepspeech_no_resnet" "librispeech_deepspeech_norm_and_spec_aug" - "fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size") + "fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size" "finewebedu_lm") VALID_RULESETS=("self" "external") # Set data and experiment paths @@ -221,7 +221,7 @@ TUNING_RULESET_FLAG="--tuning_ruleset=${TUNING_RULESET}" if [[ "${FRAMEWORK}" == "jax" ]]; then COMMAND_PREFIX="python" else - COMMAND_PREFIX="torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8" + COMMAND_PREFIX="torchrun --redirects 1:0,2:0,3:0 --standalone --nnodes=1 --nproc_per_node=4" fi # Set data directory and bucket (bucket is only relevant in internal mode) diff --git a/docs/DOCUMENTATION.md b/docs/DOCUMENTATION.md index f7ac5e659..49e738408 100644 --- a/docs/DOCUMENTATION.md +++ b/docs/DOCUMENTATION.md @@ 
-55,7 +55,7 @@ The **AlgoPerf: Training Algorithms benchmark** challenges participants to submi The benchmarking process follows these **key principles**: -- 🎯 **Fixed Target, Model & Hardware:** Submitted training algorithms must train a set of [**fixed models**](#workloads) to a pre-defined validation performance target as fast as possible. All submissions use the same model architecture and are run on the same [**standardized hardware**](#benchmarking-hardware) (currently `8x NVIDIA V100 GPUs`). This isolates the training algorithm's performance and allows a fair apples-to-apples comparison. +- 🎯 **Fixed Target, Model & Hardware:** Submitted training algorithms must train a set of [**fixed models**](#workloads) to a pre-defined validation performance target as fast as possible. All submissions use the same model architecture and are run on the same [**standardized hardware**](#benchmarking-hardware) (currently `4x NVIDIA A100 GPUs`). This isolates the training algorithm's performance and allows a fair apples-to-apples comparison. - ⏱️ **Time-To-Result:** Submissions are evaluated based on the total wall-clock time required to reach the target, rewarding practical and efficient algorithms. - 🧠 **Diverse Workloads:** The benchmark includes [**8 diverse deep learning workloads**](#workloads) across domains like image classification, speech recognition, and machine translation. A submission's score is computed by aggregating its performance across all workloads, using [**performance profiles**](#algoperf-benchmark-score-via-integrated-performance-profiles), to ensure general-purpose algorithms. - 📦 **Fully-Specified Algorithms:** Submissions must be [**complete procedures**](#submission-api) and thus hyperparameter tuning is treated as part of the algorithm. Depending on the [**ruleset**](#tuning-rulesets), submissions may use parallel tuning resources. This ensures that the benchmark measures the _total_ practical cost of a training algorithm and provides practitioners with a complete method, eliminating the guesswork of how to apply it. @@ -542,7 +542,7 @@ All officially scored runs will be performed on the same benchmarking hardware t This benchmarking hardware is chosen to be easily accessible via common cloud computing providers and will likely change with each iteration of the benchmark. The specs of the benchmarking hardware for this iteration of the benchmark are: -- 8× NVIDIA V100 (16 GB) GPUs +- 4× NVIDIA A100 (40 GB) GPUs - 240 GB in RAM - 2 TB in storage (for datasets). @@ -595,7 +595,7 @@ Furthermore, all submitters must sign the following agreements:
My machine only has one GPU. How can I use this repo? -> You can run this repo on a machine with an arbitrary number of GPUs. However, the default batch sizes of our algorithms collection (e.g. `algorithms/`) are tuned for a machine with 8× NVIDIA V100 (16 GB) GPUs. You may run into OOMs if you run these algorithms with fewer than 8 GPUs. If you run into these issues because you are using a machine with less total GPU memory, please reduce the batch sizes for the submission. Note that your final submission must 'fit' on the [**benchmarking hardware**](#benchmarking-hardware), so if you are using fewer GPUs with higher per-GPU memory, please monitor your memory usage to make sure it will fit on 8× NVIDIA V100 GPUs with 16 GB of VRAM per card. +> You can run this repo on a machine with an arbitrary number of GPUs. However, the default batch sizes of our algorithms collection (e.g. `algorithms/`) are tuned for a machine with 4× NVIDIA A100 (40 GB) GPUs. You may run into OOMs if you run these algorithms with fewer than 4 GPUs. If you run into these issues because you are using a machine with less total GPU memory, please reduce the batch sizes for the submission. Note that your final submission must 'fit' on the [**benchmarking hardware**](#benchmarking-hardware), so if you are using fewer GPUs with higher per-GPU memory, please monitor your memory usage to make sure it will fit on 4× NVIDIA A100 GPUs with 40 GB of VRAM per card.
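To act on the memory advice in this FAQ answer, one option is to measure peak memory directly. A minimal sketch using PyTorch's CUDA memory statistics (the 40 GB default matches the benchmark card; `run_training_steps` is a hypothetical placeholder for a few update steps of your submission):

```python
import torch

def check_fits_benchmark_hardware(run_training_steps, vram_gb: float = 40.0) -> float:
  """Assert that peak GPU memory in this process stays under the per-card VRAM."""
  torch.cuda.reset_peak_memory_stats()
  run_training_steps()  # e.g. a handful of training steps at your batch size
  peak_gb = torch.cuda.max_memory_allocated() / 1024**3
  assert peak_gb <= vram_gb, (
    f'Peak GPU memory {peak_gb:.1f} GB exceeds the {vram_gb:.0f} GB '
    f'available per card on the benchmarking hardware.'
  )
  return peak_gb
```

Under DDP each process drives one GPU, so this per-process peak is the relevant per-card number.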
diff --git a/pyproject.toml b/pyproject.toml index e4de98f89..534f5d678 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,16 +105,15 @@ jax_cpu = [ jax_gpu = [ "jax[cuda12]==0.7.0", "algoperf[jax_core_deps]", - "nvidia-cudnn-cu12==9.10.2.21", # temporary workaround for https://github.com/jax-ml/jax/issues/30663 ] pytorch_cpu = [ - "torch==2.5.1", - "torchvision==0.20.1" + "torch==2.9.0", + "torchvision==0.24.0" ] pytorch_gpu = [ - "torch==2.5.1", - "torchvision==0.20.1", + "torch==2.9.0", + "torchvision==0.24.0", ] # Note: omit the cuda suffix and installing from the appropriate wheel will result in using locally installed CUDA. ############################################################################### diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index 4f2ae9c57..043a65791 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -71,6 +71,7 @@ 'wer', 'l1_loss', 'loss', + 'ppl', ] MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu'] diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 3423df2e1..4b7bed2b5 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -123,6 +123,8 @@ def get_summary_df(workload, workload_df, include_test_split=False): workload_df['accumulated_submission_time'] / workload_df['global_step'] ).iloc[-1][-1] + summary_df['step_hint'] = scoring_utils.get_workload_stephint(workload) + # test metrics if include_test_split: test_metric, test_target = scoring_utils.get_workload_metrics_and_targets( @@ -157,7 +159,7 @@ def get_summary_df(workload, workload_df, include_test_split=False): return summary_df -def get_submission_summary(df, include_test_split=True): +def get_submission_summary(df, include_test_split=False): """Summarizes the submission results into metric and time tables organized by workload. """ diff --git a/scoring/scoring_utils.py b/scoring/scoring_utils.py index 5be6c790c..cb63eab4b 100644 --- a/scoring/scoring_utils.py +++ b/scoring/scoring_utils.py @@ -240,3 +240,23 @@ def get_workload_metrics_and_targets(workload, split='validation'): metric = f'test/{metric_name}' target = workload_obj.test_target_value return metric, target + + +def get_workload_stephint(workload): + workload_name = re.match(WORKLOAD_NAME_PATTERN, workload).group(1) + framework = re.match(WORKLOAD_NAME_PATTERN, workload).group(2) + workload_metadata = copy.copy(WORKLOADS[workload_name]) + + # Extend path according to framework. 
+ workload_metadata['workload_path'] = os.path.join( + BASE_WORKLOADS_DIR, + workload_metadata['workload_path'] + f'{framework}', + 'workload.py', + ) + workload_init_kwargs = {} + workload_obj = workloads_registry.import_workload( + workload_path=workload_metadata['workload_path'], + workload_class_name=workload_metadata['workload_class_name'], + workload_init_kwargs=workload_init_kwargs, + ) + return workload_obj.step_hint diff --git a/scoring/utils/run_workloads.py b/scoring/utils/run_workloads.py index 273881c5a..d8e0172fa 100644 --- a/scoring/utils/run_workloads.py +++ b/scoring/utils/run_workloads.py @@ -241,7 +241,8 @@ def main(_): # For each runnable workload check if there are any containers running and if not launch next container command for workload in workloads: - run_key = prng.fold_in(rng_subkey, hash(workload)) + workload_foldin = hash(workload) % 9 + run_key = prng.fold_in(rng_subkey, workload_foldin) run_seed = run_key[0] # arbitrary base_workload_name = get_base_workload_name(workload) wait_until_container_not_running() @@ -270,6 +271,7 @@ def main(_): 'docker run -t -d -v /home/kasimbeg/data/:/data/ ' '-v /home/kasimbeg/experiment_runs/:/experiment_runs ' '-v /home/kasimbeg/experiment_runs/logs:/logs ' + '-v /home/kasimbeg/algorithmic-efficiency:/algorithmic-efficiency ' f'{mount_repo_flag}' '--gpus all --ipc=host ' f'{docker_image_url} ' diff --git a/scoring/utils/workload_metadata_external_tuning.json b/scoring/utils/workload_metadata_external_tuning.json index c7d4ae195..3d9f78ca1 100644 --- a/scoring/utils/workload_metadata_external_tuning.json +++ b/scoring/utils/workload_metadata_external_tuning.json @@ -24,7 +24,7 @@ "dataset": "librispeech" }, "criteo1tb": { - "max_steps": 10666, + "max_steps": 15666, "dataset": "criteo1tb" }, "librispeech_conformer": { diff --git a/submission_runner.py b/submission_runner.py index 552c99b79..d15bda74b 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -256,7 +256,6 @@ def train_once( 'librispeech_conformer', 'ogbg', 'criteo1tb', - 'imagenet_vit', 'librispeech_deepspeech', ] eager_backend_workloads = [] @@ -266,6 +265,7 @@ def train_once( 'librispeech_deepspeech', 'ogbg', 'wmt', + 'imagenet_vit', ] base_workload = workloads.get_base_workload_name(workload_name) if base_workload in compile_error_workloads: @@ -352,7 +352,6 @@ def train_once( log_dir, flags.FLAGS, hyperparameters ) workload.attach_metrics_logger(metrics_logger) - global_start_time = get_time() train_state['last_step_end_time'] = global_start_time diff --git a/tests/test_step_times.py b/tests/test_step_times.py new file mode 100644 index 000000000..22868d67d --- /dev/null +++ b/tests/test_step_times.py @@ -0,0 +1,199 @@ +"""Tests that JAX and PyTorch step times are within 25% of each other. + +This test runs each workload for a number of steps with both JAX and PyTorch, +captures the step_time_ms metric, and asserts they are within 25%. 
+""" + +import re +import subprocess +import sys +import tempfile +from pathlib import Path + +from absl import flags, logging +from absl.testing import absltest, parameterized + +FLAGS = flags.FLAGS +FLAGS(sys.argv) + +MAX_STEPS = 101 +TOLERANCE = 0.25 + +WORKLOADS = [ + 'imagenet_vit', +] + +DATA_DIRS = { + 'imagenet_resnet': '/opt/data/imagenet/', + 'imagenet_vit': '/opt/data/imagenet/', + 'librispeech_conformer': '/opt/data/librispeech', + 'librispeech_deepspeech': '/opt/data/librispeech', + 'criteo1tb': '/opt/data/criteo1tb', + 'fastmri': '/opt/data/fastmri', + 'ogbg': '/opt/data/ogbg', + 'wmt': '/opt/data/wmt', +} + +CONDA_ENVS = { + 'jax': 'ap11_jax', + 'pytorch': 'ap11_torch_latest', +} + + +def get_data_dir(workload: str, framework: str) -> str: + """Map workload to its data directory.""" + base_dir = DATA_DIRS.get(workload, '/opt/data') + if workload in ['imagenet_resnet', 'imagenet_vit']: + return base_dir + framework + return base_dir + + +def run_workload(workload: str, framework: str, output_file: Path) -> bool: + """Run a workload and capture output to file.""" + data_dir = get_data_dir(workload, framework) + experiment_dir = tempfile.mkdtemp(prefix=f'{workload}_{framework}_') + + submission_path = ( + f'algorithms/baselines/external_tuning/{framework}_nadamw_full_budget.py' + ) + tuning_search_space = ( + 'algorithms/baselines/external_tuning/tuning_search_space.json' + ) + + if framework == 'jax': + cmd = [ + 'python', + 'submission_runner.py', + f'--framework={framework}', + f'--workload={workload}', + f'--data_dir={data_dir}', + f'--experiment_dir={experiment_dir}', + f'--experiment_name={workload}_benchmark', + f'--submission_path={submission_path}', + f'--tuning_search_space={tuning_search_space}', + f'--max_global_steps={MAX_STEPS}', + '--skip_evals', + '--nosave_checkpoints', + '--nosave_intermediate_checkpoints', + ] + else: + cmd = [ + 'torchrun', + '--nproc_per_node=4', + '--standalone', + 'submission_runner.py', + f'--framework={framework}', + f'--workload={workload}', + f'--data_dir={data_dir}', + f'--experiment_dir={experiment_dir}', + f'--experiment_name={workload}_benchmark', + f'--submission_path={submission_path}', + f'--tuning_search_space={tuning_search_space}', + f'--max_global_steps={MAX_STEPS}', + '--skip_evals', + '--nosave_checkpoints', + '--nosave_intermediate_checkpoints', + ] + + conda_env = CONDA_ENVS[framework] + activate_cmd = ( + f'source $(conda info --base)/etc/profile.d/conda.sh && ' + f'conda activate {conda_env} && ' + ) + full_cmd = activate_cmd + ' '.join(cmd) + + logging.info(f'Running: {workload} with {framework}') + logging.info(f'Output will be saved to: {output_file}') + + with open(output_file, 'w') as f: + result = subprocess.run( + full_cmd, + shell=True, + executable='/bin/bash', + stdout=f, + stderr=subprocess.STDOUT, + cwd=str(Path(__file__).parent.parent), + ) + + return result.returncode == 0 + + +def parse_step_time(output_file: Path) -> float | None: + """Parse the last step_time_ms from output file.""" + if not output_file.exists(): + return None + + with open(output_file, 'r') as f: + content = f.read() + + # Find all step_time_ms values + # Pattern matches: step_time_ms=123.456 or 'step_time_ms': 123.456 + pattern = r'step_time_ms[=:]\s*([\d.]+)' + matches = re.findall(pattern, content) + + if matches: + # Return the last value (most recent EMA) + return float(matches[-1]) + return None + + +named_parameters = [ + dict(testcase_name=workload, workload=workload) for workload in WORKLOADS +] + + +class 
StepTimeTest(parameterized.TestCase): + """Tests that JAX and PyTorch step times are within tolerance.""" + + @parameterized.named_parameters(*named_parameters) + def test_step_times_within_tolerance(self, workload): + """Test that JAX and PyTorch step times are within 25% of each other.""" + results = {} + + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + + for framework in ['jax', 'pytorch']: + output_file = tmpdir / f'{workload}_{framework}.out' + + success = run_workload(workload, framework, output_file) + self.assertTrue(success, f'Failed to run {workload} with {framework}') + + step_time = parse_step_time(output_file) + self.assertIsNotNone( + step_time, + f'Could not parse step_time_ms for {workload} with {framework}', + ) + + results[framework] = step_time + logging.info(f'{workload} {framework}: {step_time:.2f} ms') + + jax_time = results['jax'] + pytorch_time = results['pytorch'] + ratio = pytorch_time / jax_time + + logging.info( + f'{workload}: JAX={jax_time:.2f}ms, PyTorch={pytorch_time:.2f}ms, ' + f'ratio={ratio:.2f}' + ) + + # Check that ratio is within tolerance (0.75 to 1.25 for 25% tolerance) + lower_bound = 1.0 - TOLERANCE + upper_bound = 1.0 + TOLERANCE + + self.assertGreaterEqual( + ratio, + lower_bound, + f'{workload}: PyTorch is more than {TOLERANCE * 100:.0f}% faster than JAX ' + f'(ratio={ratio:.2f}, expected >= {lower_bound:.2f})', + ) + self.assertLessEqual( + ratio, + upper_bound, + f'{workload}: PyTorch is more than {TOLERANCE * 100:.0f}% slower than JAX ' + f'(ratio={ratio:.2f}, expected <= {upper_bound:.2f})', + ) + + +if __name__ == '__main__': + absltest.main()
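As a standalone sanity check of the log-parsing pattern used by `parse_step_time` above (the log lines here are made up for illustration; they mirror the two formats named in the code comment):

```python
import re

# Same pattern as in parse_step_time: matches step_time_ms=123.4
# as well as the dict-style 'step_time_ms': 117.9 form.
pattern = r"step_time_ms['\"]?[=:]\s*([\d.]+)"
log = "INFO step_time_ms=123.4\nINFO {'step_time_ms': 117.9}"
matches = re.findall(pattern, log)
assert matches == ['123.4', '117.9']
print(float(matches[-1]))  # 117.9 -- the last (most recent EMA) value is kept
```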