Skip to content
Snippets Groups Projects
Commit 8e5dbcda authored by Jan Ebert
Browse files

Implement linear warmup flag for scheduler

Linear warmup is used often in the literature.
parent b5803738
No related branches found
No related tags found
No related merge requests found
Pipeline #154608 passed
...@@ -21,6 +21,8 @@ class CosineAnnealingWithWarmupLR: ...@@ -21,6 +21,8 @@ class CosineAnnealingWithWarmupLR:
max_steps: Number of steps to reach `end_lr` (including warm-up phase). max_steps: Number of steps to reach `end_lr` (including warm-up phase).
start_lr: Initial learning rate to warm up from. start_lr: Initial learning rate to warm up from.
end_lr: Final learning rate to anneal to after warmup. end_lr: Final learning rate to anneal to after warmup.
linear_warmup: Whether to use a linear schedule for the initial
warmup.
last_epoch: The index of the last step taken. Used to continue last_epoch: The index of the last step taken. Used to continue
training. If -1, no step has been taken. training. If -1, no step has been taken.
""" """
...@@ -32,6 +34,7 @@ class CosineAnnealingWithWarmupLR: ...@@ -32,6 +34,7 @@ class CosineAnnealingWithWarmupLR:
max_steps: int, max_steps: int,
start_lr: float = 0.0, start_lr: float = 0.0,
end_lr: float = 0.0, end_lr: float = 0.0,
linear_warmup: bool = False,
last_epoch: int = -1, last_epoch: int = -1,
) -> None: ) -> None:
self.optimizer = optimizer self.optimizer = optimizer
...@@ -47,6 +50,7 @@ class CosineAnnealingWithWarmupLR: ...@@ -47,6 +50,7 @@ class CosineAnnealingWithWarmupLR:
self.last_epoch = last_epoch self.last_epoch = last_epoch
self.T_max_warmup = warmup_steps self.T_max_warmup = warmup_steps
self.T_max_total = max_steps self.T_max_total = max_steps
self.linear_warmup = linear_warmup
# Initial step. # Initial step.
self.step() self.step()
...@@ -65,6 +69,18 @@ class CosineAnnealingWithWarmupLR: ...@@ -65,6 +69,18 @@ class CosineAnnealingWithWarmupLR:
def get_lr(self) -> List[float]: def get_lr(self) -> List[float]:
if self.last_epoch <= self.T_max_warmup: if self.last_epoch <= self.T_max_warmup:
if self.linear_warmup:
values = [
(
base_lr
+ (
(self.eta_max_warmup - base_lr)
* (1 - (self.last_epoch / self.T_max_warmup))
)
)
for base_lr in self.base_lrs
]
else:
values = [ values = [
( (
base_lr base_lr
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment