Skip to content
Snippets Groups Projects
Commit 8e5dbcda authored by Jan Ebert's avatar Jan Ebert
Browse files

Implement linear warmup flag for scheduler

Linear warmup is often used in the literature.
parent b5803738
No related branches found
No related tags found
No related merge requests found
Pipeline #154608 passed
......@@ -21,6 +21,8 @@ class CosineAnnealingWithWarmupLR:
max_steps: Number of steps to reach `end_lr` (including warm-up phase).
start_lr: Initial learning rate to warm up from.
end_lr: Final learning rate to anneal to after warmup.
linear_warmup: Whether to use a linear schedule for the initial
warmup.
last_epoch: The index of the last step taken. Used to continue
training. If -1, no step has been taken.
"""
......@@ -32,6 +34,7 @@ class CosineAnnealingWithWarmupLR:
max_steps: int,
start_lr: float = 0.0,
end_lr: float = 0.0,
linear_warmup: bool = False,
last_epoch: int = -1,
) -> None:
self.optimizer = optimizer
......@@ -47,6 +50,7 @@ class CosineAnnealingWithWarmupLR:
self.last_epoch = last_epoch
self.T_max_warmup = warmup_steps
self.T_max_total = max_steps
self.linear_warmup = linear_warmup
# Initial step.
self.step()
......@@ -65,6 +69,18 @@ class CosineAnnealingWithWarmupLR:
def get_lr(self) -> List[float]:
if self.last_epoch <= self.T_max_warmup:
if self.linear_warmup:
values = [
(
base_lr
+ (
(self.eta_max_warmup - base_lr)
* (1 - (self.last_epoch / self.T_max_warmup))
)
)
for base_lr in self.base_lrs
]
else:
values = [
(
base_lr
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment