What does register_buffer do?

I’m working through a tutorial on transformers (Tutorial 6: Transformers and Multi-Head Attention — UvA DL Notebooks v1.2 documentation) and I came across this block of code about positional encoding.

import math

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_len=5000):
        """
        Inputs
            d_model - Hidden dimensionality of the input.
            max_len - Maximum length of a sequence to expect.
        """
        super().__init__()

        # Create matrix of [SeqLen, HiddenDim] representing the positional encoding for max_len inputs
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # Add a leading batch dimension: [1, SeqLen, HiddenDim]

        # register_buffer => Tensor which is not a parameter, but should be part of the module's state.
        # Used for tensors that need to be on the same device as the module.
        # persistent=False tells PyTorch to not add the buffer to the state dict (e.g. when we save the model)
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        # Add the encodings for the first x.size(1) positions (broadcast over the batch)
        x = x + self.pe[:, :x.size(1)]
        return x

I don’t understand what register_buffer does. Could someone explain it in simpler terms? I get that the buffer is stored alongside the model parameters but isn’t included in gradient calculations, so how is this different from just setting requires_grad to False?

I think this answer can help you. Also, another one

register_buffer registers a tensor as part of the module’s state without making it a learnable parameter. By default a buffer is included in state_dict(); the tutorial code opts out of that by passing persistent=False. More importantly, when you call model.cuda() or model.to(device), PyTorch moves the module’s buffers together with its nn.Parameters. forward adds self.pe to the input, so pe has to be on the same device as the rest of the model. Registering it as a buffer ensures it gets transferred, whereas a plain tensor attribute would stay on its original device.
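
To see how this differs from a parameter with requires_grad=False, here is a minimal sketch. BufferDemo and its attribute names are made up for illustration and are not part of the tutorial, and the dtype cast stands in for a device move so the snippet runs without a GPU.

```python
import torch
import torch.nn as nn


class BufferDemo(nn.Module):
    """Hypothetical toy module, used only to compare the options side by side."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(3))                       # learnable parameter
        self.frozen = nn.Parameter(torch.ones(3), requires_grad=False)  # parameter without gradients
        self.register_buffer("buf", torch.ones(3))                      # persistent buffer (default)
        self.register_buffer("tmp", torch.ones(3), persistent=False)    # non-persistent buffer
        self.plain = torch.ones(3)                                      # ordinary tensor attribute


model = BufferDemo()

param_names = [name for name, _ in model.named_parameters()]
buffer_names = [name for name, _ in model.named_buffers()]
state_keys = list(model.state_dict().keys())

print(param_names)   # ['weight', 'frozen']
print(buffer_names)  # ['buf', 'tmp']
print(state_keys)    # ['weight', 'frozen', 'buf']

# .to() converts parameters and buffers, but leaves plain attributes untouched.
model = model.to(torch.float64)
print(model.buf.dtype, model.tmp.dtype, model.plain.dtype)
# torch.float64 torch.float64 torch.float32
```

So a frozen parameter still shows up in model.parameters() and would be handed to any optimizer built from that iterator, while a buffer never does. Both follow the module across .to() calls, and only the non-persistent buffer is left out of state_dict().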