Skip to content
Snippets Groups Projects
Commit a0e0cbbf authored by Jan Ebert's avatar Jan Ebert
Browse files

Add "un-embedding" layer

A layer that uses the input embedding for obtaining output embeddings.
parent 3870da20
No related branches found
No related tags found
No related merge requests found
......@@ -19,6 +19,36 @@ class ResidualConnection(th.nn.Module):
return x + self.module(x)
class UnEmbedding(th.nn.Module):
    """A weight-tied "un-embedding" layer.

    "Un-embedding" means that a tensor of shape `(E,)` will be
    transformed to a tensor of shape `(N,)`, where E is the embedding
    size (number of features), N is the number of embeddings. This is in
    contrast to the standard embedding function, which transforms a
    tensor of shape `(N,)` to a tensor of shape `(E,)`.
    """

    def __init__(self, embedding: th.nn.Embedding) -> None:
        """Store a reference to `embedding`, tying this layer's weight to it.

        Args:
            embedding: Embedding layer whose weight matrix is re-used
                (weight-tied) as the output projection.
        """
        # Required before assigning submodules to an `nn.Module`; without
        # it, `self.embedding = ...` raises an AttributeError.
        super().__init__()
        self.embedding = embedding

    def forward(self, x: th.Tensor) -> th.Tensor:
        """"Un-embed" the given input tensor.

        Args:
            x: Tensor to un-embed.

        Shape:
            - x: `(..., E)`.
            - output: `(..., N)`.

            where E is the embedding size (number of features), N is the
            number of embeddings.
        """
        # Project onto the transposed embedding matrix: (..., E) @ (E, N).
        return x @ self.embedding.weight.t()
class ArgSelector(th.nn.Module):
"""Wraps a module to filter out a single arguments.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment