How to modify the positional encoding in torch.nn.Transformer?

I am doing some experiments on positional encoding, and would like to use torch.nn.Transformer for my experiments.

But there seems to be no argument for changing the positional encoding, and I cannot find where in the source code torch.nn.Transformer handles the positional encoding.

How can I replace the default sin/cos encoding with a custom encoding of my own?


Hi, I'm not an expert on PyTorch or transformers, but I think nn.Transformer doesn't include a positional encoding at all; you have to code it yourself and add it to the token embeddings.
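To illustrate that answer, here is a minimal sketch, assuming the default seq-first layout (`batch_first=False`) and using `nn.TransformerEncoder` rather than the full `nn.Transformer` to keep it short; `nn.Transformer.forward` likewise takes already-embedded `src`/`tgt` tensors, so the same pattern applies there. `EncoderWithPE` and `AddFixedTable` are hypothetical names, and the random table is only a placeholder for whatever custom encoding you want to test. The check at the end shows why an encoding is needed at all: without one, reversing the input sequence merely reverses the output.

```python
import torch
import torch.nn as nn


class EncoderWithPE(nn.Module):
    """Hypothetical wrapper: embedding -> pluggable positional encoding -> nn.TransformerEncoder."""

    def __init__(self, vocab_size, d_model, nhead, num_layers, pos_encoder):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        # Any module mapping (seq_len, batch, d_model) to the same shape can go here
        self.pos_encoder = pos_encoder
        layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=64, dropout=0.0)
        self.encoder = nn.TransformerEncoder(layer, num_layers)

    def forward(self, tokens):
        # tokens: (seq_len, batch) integer ids; batch_first defaults to False
        x = self.embed(tokens)
        x = self.pos_encoder(x)  # positions are injected here, outside the transformer
        return self.encoder(x)


class AddFixedTable(nn.Module):
    """Adds a (max_len, d_model) table; swap in any custom encoding you want to test."""

    def __init__(self, table):
        super().__init__()
        self.register_buffer("table", table)

    def forward(self, x):
        return x + self.table[: x.size(0)].unsqueeze(1)  # broadcast over the batch


torch.manual_seed(0)
seq_len, batch, d_model, vocab = 6, 2, 16, 100
tokens = torch.randint(0, vocab, (seq_len, batch))
perm = torch.arange(seq_len - 1, -1, -1)  # reverse the sequence order

no_pe = EncoderWithPE(vocab, d_model, 4, 2, nn.Identity()).eval()
custom_table = torch.randn(32, d_model)  # placeholder for your own encoding
with_pe = EncoderWithPE(vocab, d_model, 4, 2, AddFixedTable(custom_table)).eval()

with torch.no_grad():
    out, out_perm = no_pe(tokens), no_pe(tokens[perm])
    out_pe, out_pe_perm = with_pe(tokens), with_pe(tokens[perm])

# Without a positional encoding, reordering the input only reorders the output
print(torch.allclose(out[perm], out_perm, atol=1e-5))
# With one, the outputs depend on token order
print(torch.allclose(out_pe[perm], out_pe_perm, atol=1e-5))
```

The first check should print `True` and the second `False`. Because the encoding lives in its own module, trying a different scheme only means passing a different `pos_encoder`, without touching the transformer layers.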


If you are looking for a positional encoder, see this:

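import math

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (assumes an even d_model)."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe = pe.unsqueeze(0).transpose(0, 1)  # shape (max_len, 1, d_model)
        # A buffer is saved in the state_dict and moved by .to(device), but is not trained
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (seq_len, batch, d_model); the encoding is broadcast over the batch
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)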

https://pytorch.org/tutorials/beginner/transformer_tutorial.html
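For reference, the tutorial code above computes the sinusoidal encoding of the original Transformer paper. Column `2i` of `pe` holds the sine and column `2i+1` the cosine, and `div_term` is the frequency ω_i, computed through `exp` and `log` rather than as a direct power (this assumes an even `d_model`; otherwise the cosine assignment fails with a shape mismatch):

```latex
\begin{aligned}
PE_{(pos,\,2i)} &= \sin\!\left(pos \cdot \omega_i\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(pos \cdot \omega_i\right), \\
\omega_i &= \exp\!\left(-\frac{2i \,\ln 10000}{d_{\text{model}}}\right)
         = \frac{1}{10000^{2i/d_{\text{model}}}},
\qquad i = 0, 1, \dots, \tfrac{d_{\text{model}}}{2} - 1
\end{aligned}
```

To experiment with a custom encoding in this module, only the `pe` table needs to change; any tensor of shape `(max_len, 1, d_model)` works with the same `forward`.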

@Brando_Miranda can you explain why we use dropout in the positional encoding layer? It is just the input to the network, and I couldn't find the reason for it.

@RAJA_PARIKSHAT I think the idea is to keep downstream layers from overfitting to the positional encoding. Note that dropout is applied only at the end of forward, right before the return statement, i.e. to the sum of the token embeddings and the positional encoding. By tuning the dropout probability p, we can therefore influence how much of that combined signal, including information about each token's original position, reaches downstream computations (e.g. attention). This parameter seems especially important when the embeddings are learnable (which is not the case here).
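To make the learnable case concrete, here is a minimal sketch of a trainable positional embedding with the same interface as the tutorial module. `LearnedPositionalEncoding` is a hypothetical name and the small normal initialization is an assumption, not something from this thread. The table is an `nn.Parameter` instead of a buffer, so the optimizer updates it, and dropout is again applied to the sum of embeddings and positions, only in training mode:

```python
import torch
import torch.nn as nn


class LearnedPositionalEncoding(nn.Module):
    """Hypothetical learnable alternative: one trainable d_model vector per position."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # nn.Parameter (not register_buffer), so it shows up in .parameters() and is trained
        self.pe = nn.Parameter(torch.zeros(max_len, 1, d_model))
        nn.init.normal_(self.pe, std=0.02)  # assumed initialization

    def forward(self, x):
        # x: (seq_len, batch, d_model), same layout as the sinusoidal module
        x = x + self.pe[: x.size(0)]
        # Dropout zeroes random entries of the sum (embedding + position) in training mode
        return self.dropout(x)


torch.manual_seed(0)
pos_enc = LearnedPositionalEncoding(d_model=8, dropout=0.5, max_len=32)
x = torch.randn(10, 3, 8)  # stand-in for token embeddings, (seq_len, batch, d_model)

pos_enc.train()
y_train = pos_enc(x)  # random entries zeroed, survivors scaled by 1 / (1 - p)
pos_enc.eval()
y_eval = pos_enc(x)   # dropout is a no-op in eval mode
```

Raising `p` here drops more of the embedding-plus-position signal during training; whether that actually reduces overfitting to positions is the hypothesis from the reply above, not something this snippet demonstrates.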