diff --git a/tn_transformers/schedules.py b/tn_transformers/schedules.py
index 4c48992383808c4ef34e94e64149fb45a42559df..fdf73048cf6ad3c5a12c924c33eca47e01860fce 100644
--- a/tn_transformers/schedules.py
+++ b/tn_transformers/schedules.py
@@ -21,6 +21,8 @@ class CosineAnnealingWithWarmupLR:
         max_steps: Number of steps to reach `end_lr` (including warm-up phase).
         start_lr: Initial learning rate to warm up from.
         end_lr: Final learning rate to anneal to after warmup.
+        linear_warmup: Whether to use a linear schedule for the initial
+            warmup.
         last_epoch: The index of the last step taken. Used to continue
             training. If -1, no step has been taken.
     """
@@ -32,6 +34,7 @@ class CosineAnnealingWithWarmupLR:
             max_steps: int,
             start_lr: float = 0.0,
             end_lr: float = 0.0,
+            linear_warmup: bool = False,
             last_epoch: int = -1,
     ) -> None:
         self.optimizer = optimizer
@@ -47,6 +50,7 @@ class CosineAnnealingWithWarmupLR:
         self.last_epoch = last_epoch
         self.T_max_warmup = warmup_steps
         self.T_max_total = max_steps
+        self.linear_warmup = linear_warmup
 
         # Initial step.
         self.step()
@@ -65,18 +69,30 @@ class CosineAnnealingWithWarmupLR:
 
     def get_lr(self) -> List[float]:
         if self.last_epoch <= self.T_max_warmup:
-            values = [
-                (
-                    base_lr
-                    + (
-                        (self.eta_max_warmup - base_lr)
-                        * (1 + math.cos(
-                            math.pi * self.last_epoch / self.T_max_warmup
-                        )) / 2
+            if self.linear_warmup:
+                values = [
+                    (
+                        base_lr
+                        + (
+                            (self.eta_max_warmup - base_lr)
+                            * (1 - (self.last_epoch / self.T_max_warmup))
+                        )
                     )
-                )
-                for base_lr in self.base_lrs
-            ]
+                    for base_lr in self.base_lrs
+                ]
+            else:
+                values = [
+                    (
+                        base_lr
+                        + (
+                            (self.eta_max_warmup - base_lr)
+                            * (1 + math.cos(
+                                math.pi * self.last_epoch / self.T_max_warmup
+                            )) / 2
+                        )
+                    )
+                    for base_lr in self.base_lrs
+                ]
         elif self.last_epoch <= self.T_max_total:
             values = [
                 (