diff --git a/tn_transformers/schedules.py b/tn_transformers/schedules.py
index 4c48992383808c4ef34e94e64149fb45a42559df..fdf73048cf6ad3c5a12c924c33eca47e01860fce 100644
--- a/tn_transformers/schedules.py
+++ b/tn_transformers/schedules.py
@@ -21,6 +21,8 @@ class CosineAnnealingWithWarmupLR:
         max_steps: Number of steps to reach `end_lr` (including warm-up phase).
         start_lr: Initial learning rate to warm up from.
         end_lr: Final learning rate to anneal to after warmup.
+        linear_warmup: Whether to use a linear schedule for the initial
+            warmup.
         last_epoch: The index of the last step taken. Used to continue
             training. If -1, no step has been taken.
     """
@@ -32,6 +34,7 @@ class CosineAnnealingWithWarmupLR:
         max_steps: int,
         start_lr: float = 0.0,
         end_lr: float = 0.0,
+        linear_warmup: bool = False,
         last_epoch: int = -1,
     ) -> None:
         self.optimizer = optimizer
@@ -47,6 +50,7 @@ class CosineAnnealingWithWarmupLR:
         self.last_epoch = last_epoch
         self.T_max_warmup = warmup_steps
         self.T_max_total = max_steps
+        self.linear_warmup = linear_warmup
         # Initial step.
         self.step()
 
@@ -65,18 +69,30 @@ class CosineAnnealingWithWarmupLR:
     def get_lr(self) -> List[float]:
         if self.last_epoch <= self.T_max_warmup:
-            values = [
-                (
-                    base_lr
-                    + (
-                        (self.eta_max_warmup - base_lr)
-                        * (1 + math.cos(
-                            math.pi * self.last_epoch / self.T_max_warmup
-                        )) / 2
+            if self.linear_warmup:
+                values = [
+                    (
+                        base_lr
+                        + (
+                            (self.eta_max_warmup - base_lr)
+                            * (1 - (self.last_epoch / self.T_max_warmup))
+                        )
                     )
-                )
-                for base_lr in self.base_lrs
-            ]
+                    for base_lr in self.base_lrs
+                ]
+            else:
+                values = [
+                    (
+                        base_lr
+                        + (
+                            (self.eta_max_warmup - base_lr)
+                            * (1 + math.cos(
+                                math.pi * self.last_epoch / self.T_max_warmup
+                            )) / 2
+                        )
+                    )
+                    for base_lr in self.base_lrs
+                ]
         elif self.last_epoch <= self.T_max_total:
             values = [
                 (