"""
Adaptive Computation Time algorithm as introduced in
https://arxiv.org/abs/1603.08983, modified for Transformers as in
https://arxiv.org/abs/1807.03819.
"""
from typing import Any, cast, Optional, Tuple
import torch
import torch as th
from ..typing_utils import ActivationFn
class AdaptiveComputationTime(th.nn.Module):
"""Wrap the given module in the Adaptive Computation Time (ACT)
algorithm.
    The wrapped module's outputs are repeatedly fed back into it until a
    halting unit decides that they should be returned.
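
    A minimal usage sketch (illustrative only; the block and sizes below are
    placeholders): the only requirement, matching how `forward` unpacks the
    wrapped model's outputs, is that the model returns a tuple whose first
    element is the updated state:

        class Block(th.nn.Module):
            def __init__(self, d_model: int) -> None:
                super().__init__()
                self.layer = th.nn.TransformerEncoderLayer(
                    d_model, nhead=4, batch_first=True)

            def forward(self, x: th.Tensor) -> Tuple[th.Tensor]:
                return (self.layer(x),)

        act = AdaptiveComputationTime(Block(64), d_state=64)
        mean_outputs, all_outputs, halting_step = act(th.randn(2, 10, 64))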
"""
def __init__(
self,
model: th.nn.Module,
d_state: int,
max_steps: int = 100,
activation: ActivationFn = th.sigmoid,
        update_method: str = 'mean_field',  # or 'interpolate'
is_global: bool = False,
bias: bool = True,
eps: float = 1e-2,
batch_first: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__()
assert 0 < eps < 1, '`eps` must be between 0 and 1.'
assert max_steps > 0, 'have to take at least one step.'
self.max_steps = max_steps
self.update_method = update_method
self.is_global = is_global
self.eps = eps
self.batch_first = batch_first
self.model = model
self.halting_unit = HaltingUnit(
d_state,
activation=activation,
bias=bias,
device=device,
dtype=dtype,
)
@property
def halting_threshold(self) -> float:
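        # A position halts once its accumulated halting activations exceed
        # 1 - eps; keeping eps > 0 allows halting after a single step, as in
        # the original ACT formulation.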
return 1.0 - self.eps
def _initialize_buffers(
self,
x: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # Replace the embedding dimension with a singleton halting dimension.
halting_helpers_shape = x.size()[:-1] + (1,)
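        # Buffers: running sum of halting activations, per-position halted
        # mask, the step at which each position halted, and the accumulator
        # for the probability-weighted outputs.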
halting_activations_sum = th.zeros(
halting_helpers_shape, device=x.device, dtype=x.dtype)
is_halted = th.zeros(
halting_helpers_shape, device=x.device, dtype=th.bool)
halting_step = th.empty(
halting_helpers_shape, device=x.device, dtype=th.int64)
mean_outputs = th.zeros(
halting_helpers_shape, device=x.device, dtype=x.dtype)
return halting_activations_sum, is_halted, halting_step, mean_outputs
def _get_last_steps(
self,
step: int,
halting_activations_sum: torch.Tensor,
) -> torch.Tensor:
        # Since `step` starts at zero and we are currently taking it,
        # subtract one.
if step >= self.max_steps - 1:
            min_steps = th.full(
                halting_activations_sum.size(),
                self.max_steps - 1,
                device=halting_activations_sum.device,
                dtype=th.int64,
            )
else:
min_steps = th.where(
halting_activations_sum >= self.halting_threshold,
step,
-1,
)
return min_steps
def _remainder_prob(
self,
prev_halting_activations_sum: torch.Tensor,
) -> torch.Tensor:
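        # The remainder from the ACT paper: the probability mass left over
        # after the previously accumulated halting activations.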
return 1.0 - prev_halting_activations_sum
def _step(
self,
state: torch.Tensor,
step: int,
is_halted: torch.Tensor,
prev_halting_activations_sum: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
halting_activations = self.halting_unit(state)
        if self.is_global:
            # Average the halting activations over the sequence dimension so
            # that the whole sequence halts at the same step. With
            # `batch_first` the sequence dimension is the second-to-last one
            # (which also covers unbatched `(S, E)` inputs); otherwise it is
            # the first.
            seq_dim = -2 if self.batch_first else 0
            halting_activations = halting_activations.mean(
                seq_dim, keepdim=True)
halting_activations_sum = \
prev_halting_activations_sum + halting_activations
is_last_step = cast(
th.Tensor,
step == self._get_last_steps(step, halting_activations_sum),
)
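        # Already-halted positions contribute probability 0; positions that
        # halt at this step receive the remainder so that their halting
        # probabilities sum to 1 over all steps; every other position uses
        # its raw halting activation.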
halting_probs = th.where(
is_halted,
0.0,
th.where(
is_last_step,
self._remainder_prob(prev_halting_activations_sum),
halting_activations,
),
)
return halting_probs, is_last_step, halting_activations_sum
def _update_state(
self,
total_state: torch.Tensor,
new_state: torch.Tensor,
halting_probs: torch.Tensor,
) -> torch.Tensor:
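        # 'mean_field' accumulates the halting-probability-weighted states
        # (the weighted average used by ACT); 'interpolate' takes a convex
        # combination of the previous accumulator and the new state.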
if self.update_method == 'mean_field':
total_state = total_state + halting_probs * new_state
elif self.update_method == 'interpolate':
total_state = (
(1 - halting_probs) * total_state
+ halting_probs * new_state
)
return total_state
def forward(
self,
x: torch.Tensor,
*args: Any,
**kwargs: Any,
) -> Tuple[torch.Tensor, Any, torch.Tensor]:
"""Ponder on a sequence and return the final accumulated output,
        the wrapped model's final outputs without modification, and the
        step at which each position halted.
Args:
x: Sequence to ponder on.
args: Any other arguments to pass to each model forward
pass.
kwargs: Any other keyword arguments to pass to each model
forward pass.
Shape:
- x: `(S, E)` for unbatched input, `(N, S, E)` if
`batch_first=True` or `(S, N, E)` if `batch_first=False`.
- output: `(S, E)` for unbatched input, `(N, S, E)` if
`batch_first=True` or `(S, N, E)` if `batch_first=False`.
where S is the sequence length, N is the batch size, E is
the number of features.
"""
(
halting_activations_sum,
is_halted,
halting_step,
mean_outputs,
) = self._initialize_buffers(x)
        # halting_probs_sum = th.zeros_like(halting_activations_sum)
outputs = x
for step in range(self.max_steps):
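            # The wrapped model may return several values; only the first one
            # is treated as the state that is pondered on further.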
all_outputs = self.model(outputs, *args, **kwargs)
outputs = all_outputs[0]
halting_probs, is_last_step, halting_activations_sum = self._step(
outputs,
step,
is_halted,
halting_activations_sum,
)
halting_step = th.where(
is_last_step & ~is_halted, step, halting_step)
# halting_probs_sum += halting_probs
is_halted = is_halted | is_last_step
mean_outputs = self._update_state(
mean_outputs, outputs, halting_probs)
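            # Stop pondering early once every position has halted.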
if is_halted.all():
break
# assert th.allclose(
# halting_probs_sum, th.ones_like(halting_probs_sum))
return mean_outputs, all_outputs, halting_step
class HaltingUnit(th.nn.Module):
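    """Compute a scalar halting activation for each sequence position."""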
def __init__(
self,
d_state: int,
activation: ActivationFn = th.sigmoid,
bias: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__()
self.linear = th.nn.Linear(
d_state, 1, bias=bias, device=device, dtype=dtype)
self.activation = activation
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: Sequence to give halting activations for.
Shape:
- x: `(S, E)` for unbatched input, `(N, S, E)` if
`batch_first=True` or `(S, N, E)` if `batch_first=False`.
            - output: `(S, 1)` for unbatched input, `(N, S, 1)` if
              `batch_first=True` or `(S, N, 1)` if `batch_first=False`.
where S is the sequence length, N is the batch size, E is
the number of features.
"""
return self.activation(self.linear(x))