I am writing a PositionalEmbedding() module in PyTorch, implementing the sinusoidal positional encoding from “Attention Is All You Need”. According to the paper, the module PositionalEmbedding() should contain no learnable parameters.

The initialization of the Embedding weights is as follows:

```
import math
import torch

class TrigonometricPositionalEmbedding(torch.nn.Module):
    def __init__(self, embedding_number, dimension, padding_idx):
        super().__init__()
        position = torch.arange(0, embedding_number).unsqueeze(1)
        # Frequencies 10000^(-2i/dimension) from the paper; each sin/cos
        # pair at columns (2i, 2i + 1) shares the same frequency.
        multiplicator = torch.exp(-(math.log(10000.0) / dimension) * torch.arange(0, dimension, 2))
        sin_weight = torch.sin(position * multiplicator)
        cos_weight = torch.cos(position * multiplicator)
        weight = torch.zeros(embedding_number, dimension)
        weight[:, 0::2] = sin_weight
        weight[:, 1::2] = cos_weight
```
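As a sanity check (a sketch with arbitrary test sizes of my own choosing, not values from the paper), the resulting table can be spot-checked against the paper's closed form PE(pos, 2i) = sin(pos / 10000^(2i/d)):

```python
import math
import torch

# Arbitrary test sizes (assumptions for this check only).
embedding_number, dimension = 50, 16

position = torch.arange(0, embedding_number).unsqueeze(1)
multiplicator = torch.exp(-(math.log(10000.0) / dimension) * torch.arange(0, dimension, 2))
weight = torch.zeros(embedding_number, dimension)
weight[:, 0::2] = torch.sin(position * multiplicator)
weight[:, 1::2] = torch.cos(position * multiplicator)

# Spot-check one entry against the paper's formula PE(pos, 2i) = sin(pos / 10000^(2i/d)).
pos, i = 7, 3
expected_sin = math.sin(pos / (10000.0 ** (2 * i / dimension)))
assert abs(weight[pos, 2 * i].item() - expected_sin) < 1e-5
```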

Next, because the `weight` in the module should not be learnable, calling `detach()` on the `weight` came to mind. That did not seem like the most elegant approach, so I investigated further and found two kinds of implementation, exemplified by what OpenNMT and FAIRSeq have done respectively.
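For reference, my mental model of `detach()` (a toy check, independent of the class above): it returns a tensor that shares the same storage as the original but is cut out of the autograd graph.

```python
import torch

# Pretend the table were learnable, just to see what detach() does.
w = torch.zeros(4, 8, requires_grad=True)
out = w.detach()

# The detached tensor is excluded from autograd...
assert out.requires_grad is False
# ...but shares storage with the original tensor.
assert out.data_ptr() == w.data_ptr()
```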

Following the method of **OpenNMT**, my implementation is as follows:

```
class TrigonometricPositionalEmbedding(torch.nn.Module):
    def __init__(self, embedding_number, dimension, padding_idx):
        ...
        weight[:, 1::2] = cos_weight
        # Registered as a buffer: part of the module's state, but not a
        # learnable parameter.
        self.register_buffer('weight', weight)

    def forward(self, position):
        return torch.index_select(self.weight, 0, position)
```
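If I understand `register_buffer` correctly, the buffer shows up in `state_dict()` and follows the module across device/dtype conversions, but is invisible to the optimizer. A minimal sketch with a toy module (not the class above) to confirm that behaviour:

```python
import torch

class BufferDemo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('weight', torch.ones(4, 8))

m = BufferDemo()

# The buffer is persisted with the module's state...
assert 'weight' in m.state_dict()
# ...but it is not learnable, so the optimizer never sees it.
assert len(list(m.parameters())) == 0
# No gradient ever flows into the table.
assert m.weight.requires_grad is False

# Module-wide conversions such as .half()/.to() also move buffers.
m = m.half()
assert m.weight.dtype == torch.float16
```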

Following the method of **FAIRSeq**, my implementation is as follows:

```
class TrigonometricPositionalEmbedding(torch.nn.Module):
    def __init__(self, embedding_number, dimension, padding_idx):
        ...
        weight[:, 1::2] = cos_weight
        # Plain tensor attribute: neither a Parameter nor a buffer.
        self.weight = weight

    def forward(self, position):
        # detach() cuts the lookup result out of the autograd graph.
        return torch.index_select(self.weight, 0, position).detach()
```
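By contrast, as far as I can tell, a plain tensor attribute stays out of both `state_dict()` and `parameters()`, and module-wide conversions like `.half()`/`.to()` skip it, so it has to be moved or converted by hand. A toy check of my understanding (again not the class above):

```python
import torch

class PlainAttributeDemo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Plain tensor, not a Parameter and not a registered buffer.
        self.weight = torch.ones(4, 8)

m = PlainAttributeDemo()

# Absent from the persistent state and from the optimizer's view.
assert 'weight' not in m.state_dict()
assert len(list(m.parameters())) == 0

# Module-wide conversions skip it: after .half() the tensor is still float32.
m = m.half()
assert m.weight.dtype == torch.float32
```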

So I am curious about the difference between the two methods: what is the difference between a **registered buffer** and a **detached plain tensor attribute**?

(The question was also posted on Stack Overflow but has not received an answer yet.)