diff --git a/lee_transformers/layers/utils.py b/lee_transformers/layers/utils.py index a10563b7275a56e18d34a3a37e45296be9892de9..738d6ddb94929c6c0aed1660ed37412d60e9064a 100644 --- a/lee_transformers/layers/utils.py +++ b/lee_transformers/layers/utils.py @@ -19,6 +19,36 @@ class ResidualConnection(th.nn.Module): return x + self.module(x) +class UnEmbedding(th.nn.Module): + """An weight-tied "un-embedding" layer. + + "Un-embedding" means that a tensor of shape `(E,)` will be + transformed to a tensor of shape `(N,)`, where E is the embedding + size (number of features), N is the number of embeddings. This is is + contrast to the standard embedding function, which transforms a + tensor of shape `(N,)` to a tensor of shape `(E,)`. + """ + + def __init__(self, embedding: th.nn.Embedding) -> None: + self.embedding = embedding + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """"Un-embed" the given input tensor. + + Args: + x: Tensor to un-embed. + + Shape: + - x: `(..., E)`. + + - output: `(..., N)`. + + where E is the embedding size (number of features), N is the + number of embeddings. + """ + return x @ self.embedding.weight.t() + + class ArgSelector(th.nn.Module): """Wraps a module to filter out a single arguments.