Skip to content
Snippets Groups Projects
Select Git revision
  • main
1 result

evaluate_metrics.py

Blame
  • adaptive_computation_time.py 7.72 KiB
    """
    Adaptive Computation Time algorithm as introduced in
    https://arxiv.org/abs/1603.08983, modified for Transformers as in
    https://arxiv.org/abs/1807.03819.
    """
    
    from typing import Any, cast, Optional, Tuple
    
    import torch
    import torch as th
    
    from ..typing_utils import ActivationFn
    
    
    class AdaptiveComputationTime(th.nn.Module):
        """Wrap the given module in the Adaptive Computation Time (ACT)
        algorithm.

        Its outputs will be re-fed into itself until a halting unit decides
        that the outputs are to be returned.

        Args:
            model: Module to ponder with. It is called as
                `model(state, *args, **kwargs)` and its result is indexed
                with `[0]` to obtain the next state.
            d_state: Number of features of the state; this is the input
                size of the halting unit.
            max_steps: Maximum number of pondering steps.
            activation: Activation applied by the halting unit.
            update_method: How per-step outputs are accumulated; either
                `'mean_field'` (sum weighted by the halting
                probabilities) or `'interpolate'` (running interpolation
                between the accumulator and the new output).
            is_global: Whether to average the halting activations over the
                sequence dimension, so that all positions of a sequence
                receive the same halting activation.
            bias: Whether the halting unit uses a bias.
            eps: A position halts once its accumulated halting activations
                reach `1 - eps`.
            batch_first: Whether batched inputs are laid out as
                `(N, S, E)` instead of `(S, N, E)`.
            device: Device passed on to the halting unit.
            dtype: Data type passed on to the halting unit.
        """

        def __init__(
                self,
                model: th.nn.Module,
                d_state: int,
                max_steps: int = 100,
                activation: ActivationFn = th.sigmoid,
                update_method: str = 'mean_field',  # interpolate
                is_global: bool = False,
                bias: bool = True,
                eps: float = 1e-2,
                batch_first: bool = True,
                device: Optional[torch.device] = None,
                dtype: Optional[torch.dtype] = None,
        ) -> None:
            super().__init__()
            assert 0 < eps < 1, '`eps` must be between 0 and 1.'
            assert max_steps > 0, 'have to take at least one step.'
            self.max_steps = max_steps
            self.update_method = update_method
            self.is_global = is_global
            self.eps = eps
            self.batch_first = batch_first

            self.model = model
            self.halting_unit = HaltingUnit(
                d_state,
                activation=activation,
                bias=bias,
                device=device,
                dtype=dtype,
            )

        @property
        def halting_threshold(self) -> float:
            """Accumulated halting activation at which a position halts."""
            return 1.0 - self.eps

        def _initialize_buffers(
                self,
                x: torch.Tensor,
        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            """Create the bookkeeping tensors for one `forward` call.

            All buffers have the shape of `x` with the last (embedding)
            dimension replaced by a singleton and live on `x`'s device.

            Returns:
                A tuple `(halting_activations_sum, is_halted, halting_step,
                mean_outputs)`. `halting_step` is allocated uninitialized;
                `forward` writes an entry when the position halts.
                `mean_outputs` starts at zero and is broadcast against the
                model outputs on the first accumulation.
            """
            # Remove embedding dimension.
            halting_helpers_shape = x.size()[:-1] + (1,)
            halting_activations_sum = th.zeros(
                halting_helpers_shape, device=x.device, dtype=x.dtype)
            is_halted = th.zeros(
                halting_helpers_shape, device=x.device, dtype=th.bool)
            halting_step = th.empty(
                halting_helpers_shape, device=x.device, dtype=th.int64)
            mean_outputs = th.zeros(
                halting_helpers_shape, device=x.device, dtype=x.dtype)
            return halting_activations_sum, is_halted, halting_step, mean_outputs

        def _get_last_steps(
                self,
                step: int,
                halting_activations_sum: torch.Tensor,
        ) -> torch.Tensor:
            """Return the step index each position treats as its last one.

            Args:
                step: Zero-based index of the step currently being taken.
                halting_activations_sum: Halting activations accumulated up
                    to and including `step`.

            Returns:
                An int64 tensor shaped like `halting_activations_sum` on the
                same device. Once the step budget is exhausted, every entry
                is `max_steps - 1`. Otherwise an entry is `step` where the
                accumulated activation reached the halting threshold and
                `-1` (never equal to a valid step) elsewhere.
            """
            # `step` starts at zero and is the step we are currently taking,
            # so the final permitted step is `max_steps - 1`.
            final_step = self.max_steps - 1
            if step >= final_step:
                # Create the tensor on the input's device; a CPU default
                # would clash with accelerator tensors downstream.
                return th.full(
                    halting_activations_sum.size(),
                    final_step,
                    dtype=th.int64,
                    device=halting_activations_sum.device,
                )

            reached = halting_activations_sum >= self.halting_threshold
            not_yet = th.full_like(reached, -1, dtype=th.int64)
            return not_yet.masked_fill(reached, step)

        def _remainder_prob(
                self,
                prev_halting_activations_sum: torch.Tensor,
        ) -> torch.Tensor:
            """Return the probability mass left before the current step."""
            return 1.0 - prev_halting_activations_sum

        def _step(
                self,
                state: torch.Tensor,
                step: int,
                is_halted: torch.Tensor,
                prev_halting_activations_sum: torch.Tensor,
        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            """Compute the halting probabilities for a single step.

            Args:
                state: Current model output to judge.
                step: Zero-based index of the step currently being taken.
                is_halted: Boolean mask of positions that halted in an
                    earlier step.
                prev_halting_activations_sum: Halting activations
                    accumulated before this step.

            Returns:
                A tuple `(halting_probs, is_last_step,
                halting_activations_sum)`. `halting_probs` is zero for
                already halted positions, the remaining probability mass for
                positions whose last step this is, and the raw halting
                activation otherwise. `is_last_step` is also true for
                positions that halted earlier and whose accumulated
                activation is still above the threshold.
            """
            halting_activations = self.halting_unit(state)
            if self.is_global:
                # The sequence axis is 1 only for batched, batch-first
                # input `(N, S, ...)`; for `(S, N, ...)` and for unbatched
                # `(S, ...)` input it is the leading axis.
                if self.batch_first and halting_activations.dim() > 2:
                    seq_dim = 1
                else:
                    seq_dim = 0
                halting_activations = halting_activations.mean(
                    seq_dim, keepdim=True)
            halting_activations_sum = \
                prev_halting_activations_sum + halting_activations

            is_last_step = cast(
                th.Tensor,
                step == self._get_last_steps(step, halting_activations_sum),
            )
            # On its last step a position receives whatever probability
            # mass is still left instead of its raw activation.
            active_probs = th.where(
                is_last_step,
                self._remainder_prob(prev_halting_activations_sum),
                halting_activations,
            )
            halting_probs = active_probs.masked_fill(is_halted, 0.0)
            return halting_probs, is_last_step, halting_activations_sum

        def _update_state(
                self,
                total_state: torch.Tensor,
                new_state: torch.Tensor,
                halting_probs: torch.Tensor,
        ) -> torch.Tensor:
            """Fold the output of one step into the accumulated output.

            Args:
                total_state: Output accumulated so far.
                new_state: Model output of the current step.
                halting_probs: Halting probabilities of the current step.

            Returns:
                The updated accumulated output.

            Raises:
                ValueError: If `self.update_method` is neither
                    `'mean_field'` nor `'interpolate'`.
            """
            if self.update_method == 'mean_field':
                return total_state + halting_probs * new_state
            if self.update_method == 'interpolate':
                return (
                    (1 - halting_probs) * total_state
                    + halting_probs * new_state
                )
            # Returning `total_state` untouched here would silently yield
            # an all-zero result from `forward`.
            raise ValueError(
                f'unknown `update_method` {self.update_method!r}; '
                f"expected 'mean_field' or 'interpolate'."
            )

        def forward(
                self,
                x: torch.Tensor,
                *args: Any,
                **kwargs: Any,
        ) -> Tuple[torch.Tensor, Any, torch.Tensor]:
            """Ponder on a sequence and return the final accumulated output,
            all final outputs without modification, and the number of update
            steps.

            Args:
                x: Sequence to ponder on.
                args: Any other arguments to pass to each model forward
                    pass.
                kwargs: Any other keyword arguments to pass to each model
                    forward pass.

            Returns:
                A tuple `(mean_outputs, all_outputs, halting_step)`:
                the accumulated output, whatever the model returned in the
                last executed step, and the zero-based step index at which
                each position halted (last dimension of size one).

            Shape:
                - x: `(S, E)` for unbatched input, `(N, S, E)` if
                  `batch_first=True` or `(S, N, E)` if `batch_first=False`.

                - output: `(S, E)` for unbatched input, `(N, S, E)` if
                  `batch_first=True` or `(S, N, E)` if `batch_first=False`.

                where S is the sequence length, N is the batch size, E is
                the number of features.
            """
            (
                halting_activations_sum,
                is_halted,
                halting_step,
                mean_outputs,
            ) = self._initialize_buffers(x)

            # Invariant: over all executed steps the halting probabilities
            # of each position sum to one, because the last step hands out
            # the remaining mass.
            outputs = x
            for step in range(self.max_steps):
                all_outputs = self.model(outputs, *args, **kwargs)
                outputs = all_outputs[0]

                halting_probs, is_last_step, halting_activations_sum = self._step(
                    outputs,
                    step,
                    is_halted,
                    halting_activations_sum,
                )
                # Record the step only for positions halting right now.
                newly_halted = is_last_step & ~is_halted
                halting_step = halting_step.masked_fill(newly_halted, step)
                is_halted = is_halted | is_last_step

                mean_outputs = self._update_state(
                    mean_outputs, outputs, halting_probs)

                if is_halted.all():
                    break

            return mean_outputs, all_outputs, halting_step
    
    
    class HaltingUnit(th.nn.Module):
        """Map a state to one halting activation per position.

        The state is projected from `d_state` features down to a single
        value by a linear layer, and `activation` is applied to the result.

        Args:
            d_state: Number of input features of the linear projection.
            activation: Callable applied to the projected values.
            bias: Whether the linear projection has a bias term.
            device: Device on which to create the linear layer.
            dtype: Data type of the linear layer's parameters.
        """

        def __init__(
                self,
                d_state: int,
                activation: ActivationFn = th.sigmoid,
                bias: bool = True,
                device: Optional[torch.device] = None,
                dtype: Optional[torch.dtype] = None,
        ) -> None:
            super().__init__()
            projection = th.nn.Linear(
                in_features=d_state,
                out_features=1,
                bias=bias,
                device=device,
                dtype=dtype,
            )
            self.linear = projection
            self.activation = activation

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Compute the halting activations for a sequence.

            Args:
                x: Sequence to give halting activations for.

            Shape:
                - x: `(S, E)` for unbatched input, `(N, S, E)` if
                  `batch_first=True` or `(S, N, E)` if `batch_first=False`.

                - output: same as `x` with the feature dimension reduced
                  to one, i.e. `(S, 1)`, `(N, S, 1)` or `(S, N, 1)`
                  (assuming `activation` preserves the shape of its
                  input).

                where S is the sequence length, N is the batch size, E is
                the number of features.
            """
            projected = self.linear(x)
            return self.activation(projected)