From 8e5dbcdaf18f3035f1c2f2fe5d2bc3d74adf1660 Mon Sep 17 00:00:00 2001 From: janEbert <janpublicebert@posteo.net> Date: Mon, 11 Sep 2023 15:43:45 +0200 Subject: [PATCH] Implement linear warmup flag for scheduler Linear warmup is used often in the literature. --- tn_transformers/schedules.py | 38 +++++++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/tn_transformers/schedules.py b/tn_transformers/schedules.py index 4c48992..fdf7304 100644 --- a/tn_transformers/schedules.py +++ b/tn_transformers/schedules.py @@ -21,6 +21,8 @@ class CosineAnnealingWithWarmupLR: max_steps: Number of steps to reach `end_lr` (including warm-up phase). start_lr: Initial learning rate to warm up from. end_lr: Final learning rate to anneal to after warmup. + linear_warmup: Whether to use a linear schedule for the initial + warmup. last_epoch: The index of the last step taken. Used to continue training. If -1, no step has been taken. """ @@ -32,6 +34,7 @@ class CosineAnnealingWithWarmupLR: max_steps: int, start_lr: float = 0.0, end_lr: float = 0.0, + linear_warmup: bool = False, last_epoch: int = -1, ) -> None: self.optimizer = optimizer @@ -47,6 +50,7 @@ class CosineAnnealingWithWarmupLR: self.last_epoch = last_epoch self.T_max_warmup = warmup_steps self.T_max_total = max_steps + self.linear_warmup = linear_warmup # Initial step. self.step() @@ -65,18 +69,30 @@ class CosineAnnealingWithWarmupLR: def get_lr(self) -> List[float]: if self.last_epoch <= self.T_max_warmup: - values = [ - ( - base_lr - + ( - (self.eta_max_warmup - base_lr) - * (1 + math.cos( - math.pi * self.last_epoch / self.T_max_warmup - )) / 2 + if self.linear_warmup: + values = [ + ( + base_lr + + ( + (self.eta_max_warmup - base_lr) + * (1 - (self.last_epoch / self.T_max_warmup)) + ) ) - ) - for base_lr in self.base_lrs - ] + for base_lr in self.base_lrs + ] + else: + values = [ + ( + base_lr + + ( + (self.eta_max_warmup - base_lr) + * (1 + math.cos( + math.pi * self.last_epoch / self.T_max_warmup + )) / 2 + ) + ) + for base_lr in self.base_lrs + ] elif self.last_epoch <= self.T_max_total: values = [ ( -- GitLab