Positional Encoding Layer that allows for Batched Inputs

Hi everyone. I have written a custom module for a Positional Encoding layer in a neural network. However, it only works for “flat” inputs x, such as a 1-D tensor of individual indices, and not for tensors with multiple dimensions.

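import torch


class PositionalEncoding(torch.nn.Module):

    def __init__(self, b: float = 1.25, l: int = 80) -> None:
        super().__init__()
        self.l = int(l)  # number of frequencies
        self.b = b       # base of the exponentially increasing frequencies

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # for each frequency pi * b**i, append the sine and cosine of the scaled input,
        # giving [sin(f0*x), cos(f0*x), sin(f1*x), cos(f1*x), ...]
        x_pe = []
        for i in range(self.l):
            arg = (self.b ** i) * torch.pi * x
            x_pe.extend([torch.sin(arg), torch.cos(arg)])
        # stack the 2 * l features along dim 1: [N] or [N, 1] -> [N, 2 * l]
        return torch.stack(x_pe, dim=1).squeeze(-1)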
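To make the limitation concrete, here is a quick shape check of the module above. The class is reproduced so the snippet runs on its own, and l=4 is an arbitrary choice for readability. For a [B, N, K] input, `torch.stack(..., dim=1)` inserts the feature axis between B and N instead of appending it to the last dimension.

```python
import torch


class PositionalEncoding(torch.nn.Module):
    # copy of the module above, so this snippet is self-contained
    def __init__(self, b: float = 1.25, l: int = 80) -> None:
        super().__init__()
        self.l = int(l)
        self.b = b

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_pe = []
        for i in range(self.l):
            arg = (self.b ** i) * torch.pi * x
            x_pe.extend([torch.sin(arg), torch.cos(arg)])
        return torch.stack(x_pe, dim=1).squeeze(-1)


pe = PositionalEncoding(l=4)

# flat inputs of N = 5 scalars: each scalar gets 2 * l = 8 features, as intended
print(pe(torch.rand(5)).shape)        # torch.Size([5, 8])
print(pe(torch.rand(5, 1)).shape)     # torch.Size([5, 8])

# batched input [B, N, K] = [2, 5, 3]: the 8 features land on dim 1,
# not on the last dimension
print(pe(torch.rand(2, 5, 3)).shape)  # torch.Size([2, 8, 5, 3])
```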

My goal now is a module that can positionally encode entire batches of tensors at once, for example a [B, N, K] tensor, where B is the batch size. The desired output shape is [B, N, 2*K*self.l]. I have tried this for a few hours, and I seem to be no closer to a solution that is functionally equivalent to the module above.
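One way to pin down "functionally equivalent" is a slow reference that applies the per-element encoding of the module above to each of the K entries and concatenates the K blocks of length 2*l along the last dimension. This assumes the intended layout is element-major (all 2*l features of x[..., 0], then all of x[..., 1], and so on); `reference_encode` is a hypothetical helper name, not part of any library.

```python
import torch


def reference_encode(x: torch.Tensor, b: float = 1.25, l: int = 80) -> torch.Tensor:
    """Slow per-element reference: [..., K] -> [..., 2 * K * l].

    Assumes each of the K entries gets its own block
    [sin(f0*x), cos(f0*x), sin(f1*x), cos(f1*x), ...] with f_i = pi * b**i,
    and the K blocks are concatenated in order along the last dimension.
    """
    blocks = []
    for k in range(x.shape[-1]):
        xk = x[..., k]
        feats = []
        for i in range(l):
            arg = (b ** i) * torch.pi * xk
            feats.extend([torch.sin(arg), torch.cos(arg)])
        blocks.append(torch.stack(feats, dim=-1))  # [..., 2 * l]
    return torch.cat(blocks, dim=-1)                # [..., 2 * K * l]


x = torch.rand(2, 5, 3)            # [B, N, K]
out = reference_encode(x, l=4)
print(out.shape)                   # torch.Size([2, 5, 24])
```

A vectorized version can then be checked against this with `torch.allclose` on small values of l.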

My main idea: I would store the input shape of x, flatten x except for the last dimension, and unsqueeze a new dimension at the second-to-last position for the frequencies. I'd then compute all the frequencies as a tensor with as many dimensions as x, and multiply x with the frequencies to get a tensor that is conceptually the same as `arg` above. Finally, I want to take the sine and the cosine of this tensor and stack them such that, for every element of the tensor that is positionally encoded, the elements of its embedding are ordered `[sin(x1), cos(x1), sin(x2), cos(x2), …]`, i.e. the same per-element ordering as the module above.
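A minimal sketch of this idea, assuming the element-major layout (all 2*l features of x[..., 0] first, then x[..., 1], and so on). Instead of flattening, it broadcasts over any leading dimensions; `BatchedPositionalEncoding` is a hypothetical name. The key step is stacking sine and cosine on a *new* trailing axis of size 2 and then flattening the last three axes, so that the pair index varies fastest.

```python
import torch


class BatchedPositionalEncoding(torch.nn.Module):
    """Maps [..., K] -> [..., 2 * K * l] with per-element interleaved sin/cos.

    Sketch only: assumes element-major layout, i.e. for each of the K entries
    [sin(f0*x), cos(f0*x), sin(f1*x), cos(f1*x), ...] with f_i = pi * b**i.
    """

    def __init__(self, b: float = 1.25, l: int = 80) -> None:
        super().__init__()
        self.l = int(l)
        self.b = b
        # frequencies pi * b**i as a non-persistent buffer, so module.to(device)
        # moves them along without storing them in the state_dict
        freqs = torch.pi * (b ** torch.arange(self.l, dtype=torch.float32))
        self.register_buffer("freqs", freqs, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # [..., K, 1] * [l] -> [..., K, l] via broadcasting
        arg = x.unsqueeze(-1) * self.freqs
        # new trailing axis of size 2 holding (sin, cos) pairs: [..., K, l, 2]
        pe = torch.stack([torch.sin(arg), torch.cos(arg)], dim=-1)
        # merge the last three axes in row-major order:
        # element k outermost, then frequency i, then sin/cos innermost
        return pe.flatten(start_dim=-3)


enc = BatchedPositionalEncoding(l=4)
x = torch.rand(2, 5, 3)            # [B, N, K]
out = enc(x)
print(out.shape)                   # torch.Size([2, 5, 24])
```

For a [N, 1] input this produces [N, 2*l] in the same order as the original loop. If a frequency-major layout were wanted instead, only the axis order before the final flatten would change.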

Problem: I can’t find a way to achieve this interleaving of the sines and the cosines. I tried stack(), view(), reshape(), permute(), dstack(), and multiple variations and combinations thereof, but to no avail. I also tried einops.rearrange, but did not get the desired result either.
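The interleaving itself can be isolated on a toy example. Concatenating keeps the two blocks apart, while stacking on a new last axis pairs up matching entries, and flattening that axis together with the one before it reads the pairs in turn. The integer tensors `s` and `c` are stand-ins for the sines and cosines.

```python
import torch

s = torch.tensor([[1, 2, 3]])     # stand-in for the sines,   shape [1, 3]
c = torch.tensor([[10, 20, 30]])  # stand-in for the cosines, shape [1, 3]

# concatenation keeps the blocks separate
print(torch.cat([s, c], dim=-1).tolist())  # [[1, 2, 3, 10, 20, 30]]

# stacking on a NEW last axis pairs entries up: shape [1, 3, 2]
paired = torch.stack([s, c], dim=-1)
# flattening the last two axes reads each (s, c) pair in turn
print(paired.flatten(start_dim=-2).tolist())  # [[1, 10, 2, 20, 3, 30]]
```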

For reference, this is an approach I tried which returns something similar, but not quite correct:

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # store the input shape to recover it later, flatten all dimensions except the last one,
        # and add a new dimension for the frequencies
        old_shape = x.shape
        x_flat = x.view(-1, 1, old_shape[-1])

        # compute the exponentially increasing frequencies and multiply them with the flattened input
        # using PyTorch's broadcasting functionality
        freqs = torch.pi * (self.b ** torch.arange(self.l, dtype=torch.float32, device=x.device))
        arg = x_flat * freqs.view(1, self.l, 1)

        # take the sine and the cosine of the arguments, and interleave them in the last dimension
        # we want [sin(x1), cos(x1), sin(x2), cos(x2), ...], not [sin(x1), sin(x2), ..., cos(x1), cos(x2), ...]
        # to do so, we stack the sines and the cosines in the last dimension and then flatten the
        # last two dimensions.
        x_pe_int_extradim = torch.dstack([torch.sin(arg), torch.cos(arg)])
        x_pe_int = x_pe_int_extradim.view(x_pe_int_extradim.size(0), -1)
        x_pe = x_pe_int.view(*old_shape[:-1], -1)
        return x_pe
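A likely cause, going by the documented behavior of `torch.dstack`: for 3-D inputs it concatenates along dim 2 rather than adding a new axis, so the [M, l, K] arguments become [M, l, 2K], with each frequency's sines followed by its cosines, and flattening then puts the frequency axis outermost. The sketch below makes the ordering visible with integer labels instead of real sines and cosines (arg[0, i, k] = 10*i + k, and "+100" marks a cosine). It assumes the element-major target layout.

```python
import torch

l, K = 2, 2
# label each argument by (frequency i, element k): arg[0, i, k] = 10 * i + k,
# mimicking the [M, l, K] layout of `arg` in the attempt, with M = 1
arg = 10 * torch.arange(l).view(1, l, 1) + torch.arange(K).view(1, 1, K)
sin_part, cos_part = arg, arg + 100  # stand-ins: "100 + label" marks a cosine

# what the attempt does: dstack concatenates along dim 2 -> [M, l, 2K]
attempt = torch.dstack([sin_part, cos_part]).reshape(1, -1)
print(attempt.tolist())  # [[0, 1, 100, 101, 10, 11, 110, 111]]

# fix: move K in front of the frequency axis, then stack on a NEW last axis
# [M, K, l] -> [M, K, l, 2] -> [M, 2 * K * l]
fixed = torch.stack(
    [sin_part.transpose(1, 2), cos_part.transpose(1, 2)], dim=-1
).reshape(1, -1)
print(fixed.tolist())    # [[0, 100, 10, 110, 1, 101, 11, 111]]
```

In the forward above, the same fix would presumably amount to `arg = arg.transpose(1, 2)` followed by `torch.stack([torch.sin(arg), torch.cos(arg)], dim=-1).reshape(*old_shape[:-1], -1)` in place of the `dstack` and `view` lines. Using `reshape` instead of `view` for `x_flat` also avoids the error `view` may raise on non-contiguous inputs.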

I’d very much appreciate any help or pointers to resources that can help me!