Transformer Positional Encoding Class

I’m learning a transformer implementation from the Kaggle tutorial “Transformer from scratch using pytorch”. I don’t understand several lines of code in its PositionalEmbedding class:

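```python
import math

import torch
import torch.nn as nn

# register buffer in PyTorch ->
# If you have parameters in your model, which should be saved and restored in the state_dict,
# but not trained by the optimizer, you should register them as buffers.


class PositionalEmbedding(nn.Module):
    def __init__(self, max_seq_len, embed_model_dim):
        """
        Args:
            max_seq_len: maximum length of the input sequence
            embed_model_dim: dimension of the embedding (assumed even)
        """
        super().__init__()
        self.embed_dim = embed_model_dim

        # One row per position, one column per embedding dimension
        pe = torch.zeros(max_seq_len, self.embed_dim)
        for pos in range(max_seq_len):
            # i steps over the even indices, so i already plays the role of 2i in
            # PE(pos, 2i) = sin(pos / 10000^(2i/d)), PE(pos, 2i+1) = cos(pos / 10000^(2i/d));
            # each sin/cos pair shares the same frequency
            for i in range(0, self.embed_dim, 2):
                pe[pos, i] = math.sin(pos / (10000 ** (i / self.embed_dim)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** (i / self.embed_dim)))
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)


    def forward(self, x):
        """
        Args:
            x: input embeddings of shape (batch_size, seq_len, embed_dim)
        Returns:
            x: embeddings with positional encodings added, same shape
        """

        # make embeddings relatively larger
        x = x * math.sqrt(self.embed_dim)
        # add the constant positional encodings to the embeddings
        seq_len = x.size(1)
        # Variable is a deprecated wrapper; self.pe[:, :seq_len] alone is equivalent here
        x = x + torch.autograd.Variable(self.pe[:,:seq_len], requires_grad=False)
        return x
```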
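The nested loop in `__init__` fills the (max_seq_len, embed_dim) table one entry at a time. Assuming the standard formula from the original Transformer paper, PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)), with an even embedding dimension, the same kind of table can be built with broadcasting instead of Python loops. This is a minimal sketch; `sinusoidal_table` is a hypothetical helper name, not part of the tutorial.

```python
import math

import torch


def sinusoidal_table(max_seq_len: int, embed_dim: int) -> torch.Tensor:
    """Return a (max_seq_len, embed_dim) sinusoidal positional-encoding table.

    Hypothetical helper; assumes embed_dim is even.
    """
    # Column vector of positions: shape (max_seq_len, 1)
    pos = torch.arange(max_seq_len, dtype=torch.float32).unsqueeze(1)
    # Even dimension indices 0, 2, 4, ... play the role of 2i: shape (embed_dim // 2,)
    two_i = torch.arange(0, embed_dim, 2, dtype=torch.float32)
    # Frequencies 1 / 10000^(2i/d), shared by each sin/cos pair
    inv_freq = 1.0 / (10000 ** (two_i / embed_dim))
    # Broadcasting (max_seq_len, 1) * (embed_dim // 2,) -> (max_seq_len, embed_dim // 2)
    angles = pos * inv_freq
    pe = torch.zeros(max_seq_len, embed_dim)
    pe[:, 0::2] = torch.sin(angles)  # even columns get sin
    pe[:, 1::2] = torch.cos(angles)  # odd columns get cos
    return pe


table = sinusoidal_table(max_seq_len=50, embed_dim=8)
print(table.shape)  # torch.Size([50, 8])
```

Because the table is constant, computing it once in `__init__` and keeping it on the module is enough; a loop that uses the same formula should agree with this version up to float rounding.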
  • What is unsqueeze() doing? I’m guessing it adds a batch dimension, but I’m not sure…
  • Why are we using “self.register_buffer(‘pe’, pe)”? What is this doing? Why can’t we just do “self.pe = pe”?
  • I don’t understand these 2 lines of code at all:
        seq_len = x.size(1)
        x = x + torch.autograd.Variable(self.pe[:,:seq_len], requires_grad=False)
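To see what `unsqueeze(0)`, `x.size(1)` and the slice `pe[:, :seq_len]` do to shapes, here is a small standalone sketch. The sizes (max_seq_len=10, embed_dim=4, a batch of 2 sequences of length 6) are arbitrary assumptions, and random tensors stand in for the real table and embeddings.

```python
import torch

max_seq_len, embed_dim = 10, 4  # arbitrary illustrative sizes
batch_size, seq_len = 2, 6

pe = torch.randn(max_seq_len, embed_dim)
print(pe.shape)  # torch.Size([10, 4])

# unsqueeze(0) inserts a new leading axis of size 1
pe = pe.unsqueeze(0)
print(pe.shape)  # torch.Size([1, 10, 4])

# A batch of embeddings laid out as (batch_size, seq_len, embed_dim)
x = torch.randn(batch_size, seq_len, embed_dim)
print(x.size(1))  # 6, the length of the sequences in this batch

# Keep all of axis 0 and only the first seq_len positions of axis 1
pe_slice = pe[:, :seq_len]
print(pe_slice.shape)  # torch.Size([1, 6, 4])

# Broadcasting stretches the size-1 leading axis across the batch,
# so every sequence gets the same positional rows added
out = x + pe_slice
print(out.shape)  # torch.Size([2, 6, 4])
```

The slice matters because the table is built for the longest sequence the model accepts, while a given batch may be shorter.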

If anyone has experience implementing transformers, I would appreciate the help! Thanks

And to add to that, why does the PositionalEmbedding class need to inherit from torch.nn.Module at all? A positional encoding vector just contains positional information; it shouldn’t have any learnable parameters. It’s supposed to be added to the actual learnable embedding matrix…
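A small sketch of what `nn.Module` plus `register_buffer` provides even when there is nothing to learn. It uses two hypothetical toy modules, `WithBuffer` and `PlainAttr`, that each store a constant table: neither has learnable parameters, but only the registered buffer is included in `state_dict()` and converted by module-wide calls such as `.double()` (the same mechanism moves buffers with `.to(device)`).

```python
import torch
import torch.nn as nn


class WithBuffer(nn.Module):
    """Hypothetical toy module: stores its table via register_buffer."""

    def __init__(self):
        super().__init__()
        self.register_buffer('pe', torch.zeros(1, 5, 4))


class PlainAttr(nn.Module):
    """Hypothetical toy module: stores its table as a plain tensor attribute."""

    def __init__(self):
        super().__init__()
        self.pe = torch.zeros(1, 5, 4)


a, b = WithBuffer(), PlainAttr()

# Neither module has learnable parameters, so an optimizer never sees pe
print(len(list(a.parameters())), len(list(b.parameters())))  # 0 0

# Only the buffer is part of the saved/restored state
print(list(a.state_dict().keys()))  # ['pe']
print(list(b.state_dict().keys()))  # []

# Module-wide conversions (.double(), .half(), .to(...)) reach buffers,
# but not plain tensor attributes
a.double()
b.double()
print(a.pe.dtype, b.pe.dtype)  # torch.float64 torch.float32

# The buffer does not require grad, so gradients never flow into it
print(a.pe.requires_grad)  # False
```

With a plain attribute, calling `model.to('cuda')` on a parent model would leave the table on the CPU and the addition in `forward` would fail with a device mismatch. If the table should stay out of checkpoints, `register_buffer` also accepts `persistent=False`, which keeps the device/dtype handling but omits the buffer from `state_dict()`.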