Select Git revision
utils.py 1.06 KiB
from typing import Any, Union
import torch
import torch as th
class ResidualConnection(th.nn.Module):
    """Add a skip connection around an arbitrary module.

    The wrapped module is applied to the input, and the input itself is
    then added to that result, i.e. the output is ``x + module(x)``.
    The wrapped module's output therefore has to be addable to the input.
    """

    def __init__(self, module: torch.nn.Module) -> None:
        """Store ``module`` as the inner module to be wrapped.

        Args:
            module: The module whose output the input is added to.
        """
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the wrapped module's output for ``x``.

        Args:
            x: Input passed to the inner module and used as the residual.

        Returns:
            The sum of ``x`` and ``self.module(x)``.
        """
        transformed = self.module(x)
        return x + transformed
class ArgSelector(th.nn.Module):
    """Forward exactly one of the received arguments to a wrapped module.

    Of all positional and keyword arguments given to this module, only
    the one identified by ``key`` is handed to the inner module; every
    other argument is ignored. The inner module's output is returned
    unchanged.
    """

    def __init__(self, module: torch.nn.Module, key: Union[int, str]) -> None:
        """Remember the inner module and which argument to select.

        Args:
            module: The module that receives the selected argument.
            key: A string selects a keyword argument by name; any other
                value is used as an index into the positional arguments.
        """
        super().__init__()
        self.module = module
        self.key = key

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Apply the inner module to the argument selected by ``self.key``.

        Args:
            *args: Positional arguments; indexed when the key is not a string.
            **kwargs: Keyword arguments; looked up when the key is a string.

        Returns:
            Whatever the inner module returns for the selected argument.

        Raises:
            KeyError: If the key is a string not present in ``kwargs``.
            IndexError: If the key is an index outside ``args``.
        """
        # String keys address keyword arguments, everything else positionals.
        source = kwargs if isinstance(self.key, str) else args
        return self.module(source[self.key])