Intended usage of PyTorch buffers

I can see two benefits of registering tensors as buffers:

  1. They get broadcast from rank 0 to the other ranks in DDP (useful for keeping statistics such as BatchNorm running stats in sync)
  2. They are part of the state_dict, and they get cast/moved to the same dtype and device as the rest of the model by model.to() (a minimal sketch of what I mean is below)

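To make scenario 2 concrete, here is a minimal sketch with a toy `Scaler` module (the module name and the constant are just for illustration):

```python
import torch
import torch.nn as nn

class Scaler(nn.Module):
    def __init__(self):
        super().__init__()
        # Non-learnable constant registered as a buffer: it is saved in the
        # state_dict and follows the module's dtype/device via model.to().
        self.register_buffer("scale", torch.tensor(0.5))

    def forward(self, x):
        return x * self.scale

m = Scaler().to(torch.float16)      # the buffer is cast along with the module
print(m.scale.dtype)                # torch.float16
print("scale" in m.state_dict())    # True
```
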
For me, scenario 2 is perhaps the most useful: there are often cases where I have non-learnable constants that I want to keep in the same dtype and on the same device as the model (to avoid extra copies at runtime when casting or changing device). However, this sometimes clashes with scenario 1, where the implicit synchronisation across ranks is not appropriate.

This can be avoided with broadcast_buffers=False in DDP, but then I have to manually issue collectives for the buffers I do want to sync, roughly as sketched below.
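
This is roughly what I mean (`model`, `local_rank`, and the `sync_selected_buffers` helper are placeholders of mine, and it assumes the process group is already initialised):

```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Disable the automatic per-forward buffer broadcast from rank 0...
ddp_model = DDP(model, device_ids=[local_rank], broadcast_buffers=False)

def sync_selected_buffers(module, suffixes):
    """Hypothetical helper: broadcast only the buffers I do want kept in sync."""
    for name, buf in module.named_buffers():
        if name.endswith(suffixes):
            dist.broadcast(buf, src=0)

# ...and explicitly keep e.g. BatchNorm running statistics in lockstep,
# leaving my other constant buffers alone.
sync_selected_buffers(ddp_model.module, ("running_mean", "running_var"))
```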

Is there a more canonical approach to decoupling these behaviours, and getting the benefits of scenario 2 without falling victim to scenario 1?

Is the alternative here to use a Parameter instead of a buffer, but with requires_grad=False? Are there other gotchas to using a non-trainable parameter for a constant like this?
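
Something like this is what I have in mind (`ScalerParam` is just an illustrative name):

```python
import torch
import torch.nn as nn

class ScalerParam(nn.Module):
    def __init__(self):
        super().__init__()
        # Constant stored as a frozen Parameter instead of a buffer: it still
        # follows model.to() and lands in the state_dict, but (as far as I can
        # tell) DDP only broadcasts it once at construction rather than on
        # every forward pass the way it does for buffers.
        self.scale = nn.Parameter(torch.tensor(0.5), requires_grad=False)

    def forward(self, x):
        return x * self.scale

m = ScalerParam()
print([name for name, _ in m.named_parameters()])  # ['scale']
```

One gotcha I can already see is that the constant now shows up in model.parameters(), so it gets handed to the optimizer unless I filter it out.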