Skip to content
Snippets Groups Projects
Select Git revision
  • main
1 result

library.py

Blame
  • utils.py 1.06 KiB
    from typing import Any, Union
    
    import torch
    import torch as th
    
    
    class ResidualConnection(th.nn.Module):
        """Add a module's input back onto its output (a skip connection).

        The single tensor handed to `forward` is run through the wrapped
        module, and that same tensor is then summed with the module's result.
        The wrapped module's output must therefore support addition with the
        input (e.g. identical or broadcast-compatible shape).

        Args:
            module: The module whose output receives the residual.
        """

        def __init__(self, module: torch.nn.Module) -> None:
            super().__init__()
            # Assigning an nn.Module attribute registers it as a submodule,
            # so its parameters are visible to the enclosing module.
            self.module = module

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Return `x + module(x)`."""
            skip = x
            transformed = self.module(x)
            return skip + transformed
    
    
    class ArgSelector(th.nn.Module):
        """Forward exactly one of the received arguments to a wrapped module.

        Calls to this module may carry any mix of positional and keyword
        arguments. Only the argument identified by `key` reaches the inner
        module, and the inner module's result is returned unchanged.

        Args:
            module: The module that receives the selected argument.
            key: Which argument to pass through. A `str` names a keyword
                argument; any other value (normally an `int`, negative
                indices allowed) indexes the positional arguments.

        Raises (from `forward`):
            KeyError: `key` is a string that was not passed as a keyword.
            IndexError: `key` is an integer outside the positional arguments.
        """

        def __init__(self, module: torch.nn.Module, key: Union[int, str]) -> None:
            super().__init__()
            self.module = module
            self.key = key

        def forward(self, *args: Any, **kwargs: Any) -> Any:
            # String keys select from keyword arguments; anything else
            # indexes into the positional-argument tuple.
            pool = kwargs if isinstance(self.key, str) else args
            chosen = pool[self.key]
            return self.module(chosen)