|
@@ -1,21 +1,64 @@
 import warnings
-from typing import Optional
 
+from torch import Tensor
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import LRScheduler
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, LRScheduler, SequentialLR
 
 
 class DummyLRScheduler(LRScheduler):
-    def __init__(self, optimizer: Optimizer, last_epoch: Optional[int] = -1):
+    def __init__(self, optimizer: Optimizer, last_epoch: int = -1):
         super().__init__(optimizer, last_epoch)
 
-    def get_lr(self) -> list[float]:
+    def get_lr(self) -> list[float | Tensor]:
         if not self._get_lr_called_within_step:  # type error expected due to internal pytorch implementation
             warnings.warn(
-                "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning
+                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.",
+                UserWarning,
             )
 
         return [group["lr"] for group in self.optimizer.param_groups]
 
-    def _get_closed_form_lr(self) -> list[float]:
+    def _get_closed_form_lr(self) -> list[float | Tensor]:
         return self.base_lrs
+
+
+class LRSchedulerFactory:
+    @staticmethod
+    def get_linear_warmup_cosine_annealing_lr_scheduler(
+        optimizer: Optimizer,
+        warmup_steps: int,
+        total_steps: int,
+        initial_lr: float,
+        final_lr: float,
+        max_lr: float,
+        last_epoch: int = -1,
+    ) -> SequentialLR:
+        if warmup_steps <= 0:
+            raise ValueError("warmup_steps must be greater than 0.")
+        if total_steps <= warmup_steps:
+            raise ValueError("total_steps must be greater than warmup_steps.")
+
+        if not all(group["lr"] == max_lr for group in optimizer.param_groups):
+            raise ValueError(
+                "All parameter groups must have the same lr, "
+                "and it must be equal to the max_lr passed to the LR scheduler factory."
+            )
+
+        warmup_scheduler = LinearLR(
+            optimizer=optimizer,
+            start_factor=initial_lr / max_lr,
+            end_factor=1.0,
+            total_iters=warmup_steps,
+        )
+        cosine_scheduler = CosineAnnealingLR(
+            optimizer=optimizer,
+            T_max=total_steps - warmup_steps,
+            eta_min=final_lr,
+        )
+
+        return SequentialLR(
+            optimizer=optimizer,
+            schedulers=[warmup_scheduler, cosine_scheduler],
+            milestones=[warmup_steps],
+            last_epoch=last_epoch,
+        )
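
For reviewers, a minimal usage sketch of the new factory. The import path (`lr_scheduler`) and all hyperparameter values here are placeholders, not part of this diff; the one real constraint it illustrates is that the optimizer must be created with `lr=max_lr`, since the factory checks every param group's lr against `max_lr` and derives the warmup start factor as `initial_lr / max_lr`.

```python
import torch
from torch.optim import SGD

# Hypothetical import path for the module added in this diff.
from lr_scheduler import LRSchedulerFactory

model = torch.nn.Linear(4, 2)
max_lr = 1e-3

# The optimizer's lr must equal max_lr, or the factory raises ValueError.
optimizer = SGD(model.parameters(), lr=max_lr)

scheduler = LRSchedulerFactory.get_linear_warmup_cosine_annealing_lr_scheduler(
    optimizer=optimizer,
    warmup_steps=10,
    total_steps=100,
    initial_lr=1e-5,  # lr at the first step
    final_lr=1e-6,    # floor the cosine decay approaches
    max_lr=max_lr,    # peak lr, reached when warmup ends
)

lrs = []
for _ in range(100):
    optimizer.step()
    scheduler.step()  # step the scheduler once per optimizer step
    lrs.append(scheduler.get_last_lr()[0])

# lrs ramps roughly linearly from initial_lr up to max_lr over the first
# 10 steps, then follows a cosine curve down toward final_lr.
```

Note that `SequentialLR` switches from the `LinearLR` warmup to `CosineAnnealingLR` at the `warmup_steps` milestone, so `T_max=total_steps - warmup_steps` gives the cosine phase exactly the remaining step budget.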