register_buffer vs plain torch tensor attribute: performance difference

class VisionTransformer(nn.Module):
    """ Vision Transformer

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929

    Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
        - https://arxiv.org/abs/2012.12877
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None, weight_init=''):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU
        self.register_buffer('matrix', torch.zeros(12, 197, 192))  # This point

This is Vision Transformer code. I see a performance difference between declaring self.register_buffer('matrix', torch.zeros(12, 197, 192)) and declaring self.matrix = torch.zeros(12, 197, 192).cuda(). I want to use register_buffer so that the values are stored in checkpoints, but the model performs better when I don't use register_buffer.

Only one line updates the matrix above, and it runs under torch.no_grad().

I don't think any gradient flows through it, so where does the performance difference come from? I just don't understand…

I'm using DDP. Could that be the problem?

Buffers are not trainable so the gradients are irrelevant.

DDP will not share the buffer if it wasn’t properly registered (as is the case in your second approach). However, since the buffer is initialized with static zeros, there should be no difference.
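To make the distinction between the two declarations concrete, here is a minimal CPU-only sketch. The class names WithBuffer and WithAttribute are hypothetical and not part of the original model; the sketch only shows that a registered buffer appears in state_dict() and follows module-level conversions, while a plain tensor attribute does neither.

```python
import torch
import torch.nn as nn


class WithBuffer(nn.Module):
    """Hypothetical module that registers the tensor as a buffer."""

    def __init__(self):
        super().__init__()
        self.register_buffer('matrix', torch.zeros(12, 197, 192))


class WithAttribute(nn.Module):
    """Hypothetical module that stores the tensor as a plain attribute."""

    def __init__(self):
        super().__init__()
        self.matrix = torch.zeros(12, 197, 192)


registered = WithBuffer()
plain = WithAttribute()

# Only the registered buffer is saved in (and loaded from) checkpoints.
buffer_keys = list(registered.state_dict().keys())
plain_keys = list(plain.state_dict().keys())

# Module-level conversions such as .double(), .to(...) or .cuda() are applied
# to parameters and buffers, but not to plain tensor attributes.
registered.double()
plain.double()
buffer_dtype = registered.matrix.dtype
plain_dtype = plain.matrix.dtype
```

This is also why the second approach needs an explicit .cuda() call on the tensor, while the registered buffer moves together with the model.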

Could you post a small and executable code snippet showing the difference?

Hi, ptrblck.
Thanks for the response.

In both cases, the update happens in the model's forward pass.

It’s like the code below.

with torch.no_grad():
  self.matrix[ind] = x.mean(dim=0).detach() * (1-alpha) + self.matrix[ind] * alpha

Other than that, no updates are made.
The computed self.matrix is then added directly into the model's computation via self.matrix.detach().clone().
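The update described above could be reproduced in isolation roughly as follows. This is only a sketch under assumptions: the module name RunningMatrix, the alpha default of 0.9, the input shape (batch, 197, 192), and the way ind is passed in are all hypothetical, since the thread does not show the surrounding forward code.

```python
import torch
import torch.nn as nn


class RunningMatrix(nn.Module):
    """Hypothetical stand-in for the part of the model that keeps the running matrix."""

    def __init__(self, depth=12, tokens=197, dim=192, alpha=0.9):
        super().__init__()
        self.alpha = alpha  # assumed smoothing factor, not given in the thread
        self.register_buffer('matrix', torch.zeros(depth, tokens, dim))

    def forward(self, x, ind):
        # Exponential moving average of the batch mean, computed without autograd.
        with torch.no_grad():
            self.matrix[ind] = x.mean(dim=0) * (1 - self.alpha) + self.matrix[ind] * self.alpha
        # Add a detached copy so later in-place updates cannot affect the graph.
        return x + self.matrix[ind].detach().clone()


model = RunningMatrix()
x = torch.ones(4, 197, 192, requires_grad=True)
out = model(x, ind=3)
```

One thing that may be worth checking with this pattern: DistributedDataParallel has a broadcast_buffers argument (default True) that broadcasts registered buffers from rank 0 at the beginning of each forward call. With a registered buffer, the values updated on the other ranks would therefore be overwritten, whereas a plain tensor attribute stays local to each rank. Whether this accounts for the difference reported here is not confirmed in the thread.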