I can see two benefits of registering tensors as buffers:
- They get broadcast from rank 0 to the other ranks by DDP (useful for keeping statistics in sync for BatchNorm etc.)
- They are part of the state_dict and get moved/cast to the same device and dtype as the model's parameters by model.to() (see the sketch below)
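A minimal sketch of what I mean by scenario 2 (the module and buffer names are just for illustration):

```python
import torch
import torch.nn as nn

class Scaler(nn.Module):
    def __init__(self):
        super().__init__()
        # Non-learnable constant registered as a buffer: it shows up in
        # state_dict() and follows the module through .to()/.cuda()/.half().
        self.register_buffer("scale", torch.tensor(0.5))

    def forward(self, x):
        return x * self.scale

m = Scaler().to(device="cuda", dtype=torch.float16)
print(m.scale.device, m.scale.dtype)   # buffer moved and cast with the model
print("scale" in m.state_dict())       # True
```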
For me, scenario 2 is perhaps the most useful: there are often cases where I have non-learnable constants that I want to mirror the dtype/device of the model's parameters (to avoid additional copies at runtime when casting or changing device). However, this sometimes clashes with scenario 1, where the implicit synchronisation across ranks is not always appropriate.
This can be avoided with broadcast_buffers=False in DDP, but then I have to manually issue collectives for the buffers I do want to sync.
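Roughly what I'm doing today (a sketch, assuming the process group is already initialised and `model` / `local_rank` exist; the `running_mean`/`running_var` name filter is just an example):

```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Disable DDP's implicit buffer broadcast from rank 0.
ddp_model = DDP(model, device_ids=[local_rank], broadcast_buffers=False)

def sync_selected_buffers(module, names=("running_mean", "running_var")):
    # Manually broadcast only the buffers that should stay identical across
    # ranks (e.g. BatchNorm running statistics), leaving the rest alone.
    for name, buf in module.named_buffers():
        if any(n in name for n in names):
            dist.broadcast(buf, src=0)  # in-place broadcast from rank 0

sync_selected_buffers(ddp_model.module)
```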
Is there a more canonical approach to decoupling these behaviours, and getting the benefits of scenario 2 without falling victim to scenario 1?