What is the difference between `register_buffer` and `register_parameter` of `nn.Module`

I was reading the code of mask-rcnn to see how they keep their BN parameters fixed. I noticed that they use self.register_buffer to create the weight and bias, while the PyTorch BN definition uses self.register_parameter when affine=True. Can I simply think of buffers and parameters as identical, except that buffers skip gradient computation and are never updated by the optimizer?

By the way, what is the difference between directly assigning an nn.Parameter as a module attribute and using register_parameter?
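To make the buffer-based pattern from the question concrete, here is a minimal sketch of a "frozen" BatchNorm2d whose affine terms and running statistics are all buffers. It is modeled on the idea described above, not copied from the mask-rcnn code; the class name `FrozenBatchNorm2d` and the folded scale/shift formulation are my own choices for illustration.

```python
import torch
import torch.nn as nn


class FrozenBatchNorm2d(nn.Module):
    """BatchNorm2d with fixed statistics and affine terms (illustrative sketch)."""

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Buffers: saved in state_dict() and moved by .to(), but never returned
        # by parameters(), so an optimizer built from parameters() cannot update them.
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x):
        # Fold (x - mean) / sqrt(var + eps) * weight + bias into one scale and shift.
        scale = self.weight * (self.running_var + self.eps).rsqrt()
        shift = self.bias - self.running_mean * scale
        return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)
```

Loading an ordinary BatchNorm2d's weight, bias and running statistics (every state_dict key except num_batches_tracked) into this module should reproduce that layer's eval-mode output.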


Yes, you are correct in your assumption. If you have parameters in your model, which should be saved and restored in the state_dict, but not trained by the optimizer, you should register them as buffers.
Buffers won’t be returned by model.parameters(), so the optimizer won’t have a chance to update them.
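A small sketch of that split: the buffer is listed by named_buffers() and included in state_dict(), but not by parameters(), so it never reaches the optimizer. The module and attribute names are made up for illustration.

```python
import torch
import torch.nn as nn


class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(3))            # trained by the optimizer
        self.register_buffer("running_stat", torch.zeros(3))  # saved, but not trained


model = WithBuffer()
param_names = [name for name, _ in model.named_parameters()]
buffer_names = [name for name, _ in model.named_buffers()]
print(param_names)               # ['weight']
print(buffer_names)              # ['running_stat']
print(list(model.state_dict()))  # ['weight', 'running_stat']

# Only parameters() is handed to the optimizer, so step() never sees the buffer.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
```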

Both approaches work the same regarding training etc.
There are some differences in the function calls however. Using register_parameter you have to pass the name as a string, which can make the creation of a range of parameters convenient. Besides that I think it’s just coding style which one you prefer.
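As a rough illustration of the naming difference: assigning an nn.Parameter to an attribute registers it under the attribute name, while register_parameter takes the name as a string, which fits a loop. The class names and the `scale_{i}` naming scheme are hypothetical.

```python
import torch
import torch.nn as nn


class Direct(nn.Module):
    def __init__(self):
        super().__init__()
        # Assigning an nn.Parameter as an attribute registers it automatically.
        self.scale = nn.Parameter(torch.ones(1))


class Registered(nn.Module):
    def __init__(self, num_scales=3):
        super().__init__()
        self.num_scales = num_scales
        # The name is an ordinary string, so a numbered range is easy to build.
        for i in range(num_scales):
            self.register_parameter(f"scale_{i}", nn.Parameter(torch.ones(1)))

    def forward(self, x):
        # Registered parameters are still reachable as attributes.
        for i in range(self.num_scales):
            x = x * getattr(self, f"scale_{i}")
        return x
```

Both versions end up in parameters() and state_dict() the same way; only the way the name is supplied differs.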


If I have some parameters that I don’t want to be trained, can I just add them as self.some_params inside the nn.Module to preserve state? Does register_buffer do anything special in that case as compared to just storing it inside self?


@pechyonkin

If your self.some_params are nn.Parameter objects, then you don’t have to worry about this. If they are plain tensors, they won’t be in the state_dict (unless registered as a buffer).
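A quick way to check this yourself is to compare the state_dict keys for the three kinds of attributes. The names below are hypothetical.

```python
import torch
import torch.nn as nn


class Stats(nn.Module):
    def __init__(self):
        super().__init__()
        self.plain = torch.zeros(2)                       # ordinary tensor attribute
        self.register_buffer("buffered", torch.zeros(2))  # registered buffer
        self.param = nn.Parameter(torch.zeros(2))          # registered parameter


m = Stats()
print(sorted(m.state_dict()))  # ['buffered', 'param']; 'plain' is not saved
```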


What are the downsides of not using a buffer? I am currently using self.some_param inside nn.Module to keep a tensor that keeps track of running average statistics of activations. I don’t need it for backprop, only to make decisions during runtime. I want to learn more about why my approach is not an optimal one. If you could explain or give some readings, that’d be great.


Do you want it in the state_dict?

I am sorry if this is a stupid question, but I am not sure if I want that. I checked this, but I still don’t see why I would need that. Would I need buffers if I want to save the model later? Are there any other reasons I would like to use state_dict rather than just assigning to self?


As @pierrecurie explained, one reason to register the tensor as a buffer is to be able to serialize the model and restore all of its internal state.
Another is that all buffers and parameters are moved to the device when .to() or .cuda() is called on the parent model:

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.my_tensor = torch.randn(1)                    # plain attribute
        self.register_buffer('my_buffer', torch.randn(1))  # buffer
        self.my_param = nn.Parameter(torch.randn(1))       # parameter

    def forward(self, x):
        return x


model = MyModel()
print(model.my_tensor)
# tensor([0.9329])
print(model.state_dict())
# OrderedDict([('my_param', tensor([-0.2471])), ('my_buffer', tensor([1.2112]))])

model.cuda()  # requires a CUDA-capable GPU
print(model.my_tensor)
# tensor([0.9329])
print(model.state_dict())
# OrderedDict([('my_param', tensor([-0.2471], device='cuda:0')), ('my_buffer', tensor([1.2112], device='cuda:0'))])
```

As you can see, model.my_tensor is still on the CPU, where it was created, while all parameters and buffers were moved to the GPU after calling model.cuda().
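Applied to the running-average use case raised earlier in the thread, a minimal sketch might keep the statistics in a buffer and update it in place under torch.no_grad(). The class name, the exponential-average update and the `momentum` argument are assumptions for illustration, not a prescribed API.

```python
import torch
import torch.nn as nn


class ActivationTracker(nn.Module):
    """Keeps an exponential running mean of its input (hypothetical example)."""

    def __init__(self, num_features, momentum=0.1):
        super().__init__()
        self.momentum = momentum
        # Buffer: saved in state_dict() and moved by .to()/.cuda(), but not trained.
        self.register_buffer("running_mean", torch.zeros(num_features))

    def forward(self, x):
        if self.training:
            with torch.no_grad():
                batch_mean = x.mean(dim=0)
                # In-place update keeps the same registered tensor object.
                self.running_mean.mul_(1 - self.momentum).add_(batch_mean, alpha=self.momentum)
        return x
```

Because the statistics live in a buffer, they follow the model to the GPU and are restored by load_state_dict() without any extra code.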


Thanks for the clarification! Now it makes total sense. I will actually use a buffer, since I am going to use a GPU at some point for the model I am building.


@ptrblck probably another dumb question, but why wouldn’t I just use nn.Parameter for both my_tensor and my_param and set requires_grad=False for the first? How would that differ from the example in your post?


I think there wouldn’t be a difference regarding the model training, gradient flow etc., so you could probably use this approach.
However, it might be confusing to other users who are using your code to see some “buffers” in model.parameters(). :wink:
Also, you would pass these “buffers” to the optimizer if you just pass all of model.parameters().
Again, this won’t mess with your training, but the optimizer will unnecessarily have to skip these buffers in its step() method.

I would describe it as a “clean” code style to separate buffers and parameters.
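For comparison, a sketch of the nn.Parameter(requires_grad=False) alternative: the frozen tensor shows up in parameters() and receives no gradient. Filtering on requires_grad before building the optimizer is one common pattern, not something the optimizer requires.

```python
import torch
import torch.nn as nn


class FrozenAsParam(nn.Module):
    def __init__(self):
        super().__init__()
        self.my_param = nn.Parameter(torch.randn(1))
        # Behaves like a buffer during training, but is still listed by parameters().
        self.my_tensor = nn.Parameter(torch.randn(1), requires_grad=False)


model = FrozenAsParam()
print([n for n, _ in model.named_parameters()])  # ['my_param', 'my_tensor']

# Hand only the trainable parameters to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.1)
```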


Ah, thanks. An example where I find this distinction difficult is in the context of fixed positional encodings in the Transformer model. Typically I see implementations where the fixed positional encodings are registered as buffers but I’d consider these tensors as non-learnable parameters (that should show up in the list of model parameters), especially when comparing between methods that don’t rely on such injection of fixed tensors.

Re. your last remark, I guess this should do the trick, but from that thread I understand it is poor coding practice.

So in general:
buffers = ‘fixed tensors / non-learnable parameters / stuff that does not require gradient’
parameters = ‘learnable parameters, requires gradient’
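As an illustration of the positional-encoding case, here is a sketch of the sinusoidal encoding from the original Transformer paper stored as a buffer. The class name, `max_len` default and (batch, seq_len, d_model) input layout are assumptions, and `d_model` is assumed to be even.

```python
import math

import torch
import torch.nn as nn


class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        # Fixed, non-learnable tensor: a buffer, so it follows .to() and is saved.
        self.register_buffer("pe", pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model); add the encoding for the first seq_len positions.
        return x + self.pe[: x.size(1)]
```

With this layout the tensor is excluded from parameters(), which is exactly the point of contention above: it will not be counted if you report parameter counts via model.parameters().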


Sort of hijacking the thread, but I am struggling to implement a capsule net. It needs some non-trainable variables that I don’t want in the state_dict, since they are just computed statistics.
The problem is that these variables are created in the model code, where I use code like torch.zeros(b, h, w).cuda().
But this is ugly, and if I use torch.zeros(b, h, w) instead, these variables will not be sent to the GPU when we call model.to(device).
Please let me know if there is a better way to construct them.

Could you describe the usage of these tensors a bit?
I assume they are not defining the model state, as you don’t want to have them in the state_dict, which means these tensors are independent of the model?
Could you create these tensors then during runtime, e.g. by using the device attribute of a parameter or buffer?
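A minimal sketch of that suggestion, assuming the scratch tensor can be rebuilt on every forward call: new_zeros takes the device and dtype from the input, so no .cuda() is hard-coded and nothing extra lands in the state_dict. ToyRouting and its zero-initialized routing logits are hypothetical and only stand in for the capsule-net computation.

```python
import torch
import torch.nn as nn


class ToyRouting(nn.Module):
    """Hypothetical layer; only the tensor-creation pattern matters here."""

    def __init__(self, in_features, out_caps):
        super().__init__()
        self.proj = nn.Linear(in_features, out_caps)

    def forward(self, x):
        # Scratch logits created at runtime on x's device and with x's dtype.
        # Equivalent: torch.zeros(x.size(0), n, device=x.device, dtype=x.dtype)
        logits = x.new_zeros(x.size(0), self.proj.out_features)
        coeffs = torch.softmax(logits, dim=1)  # uniform coefficients to start with
        return coeffs * self.proj(x)
```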


Do you mean something like

```python
independent_tensor = torch.zeros(3, 3).to(feat_map.device)
```

Yeah, it’s a nicer workaround. Thanks.
But it would be better if there were a way to do this without setting the device in the model code, so that the whole model can be sent to the GPU or CPU just by calling model.to(device).

model.to() transfers all “states” to the specified device.
However, your use case seems as if the mentioned tensors should not be in the state_dict, which seems like a special use case.
Could you therefore explain the use case a bit, i.e.:

  • what are these tensors used for
  • are they specific to the model
  • how do you create them (model dependent or not?)

Sorry about the vagueness.
In the example of a capsule net:

An example of these tensors in a capsule net implementation

  • These tensors are used to compute the coefficients assigned to the feature maps, which themselves produce these coefficients under torch.no_grad().

  • These tensors have nothing to do with training or learning. They are just values computed by a fixed procedure (dynamic routing in the capsule net).

  • They can simply be seen as the cosine similarity between certain layers’ feature maps.

And as you mentioned:

model.to() transfers all “states” to the specified device.

It seems model.to() only cares about the “states”. Maybe some_tensor.to(feature_map.device) is the best we can get.


You could override the `to` or `_apply` methods of your module so that they also transfer that specific tensor. This way you would not have to pass the device to any other part of your module.
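A sketch of the `_apply` route, with a caveat: `_apply` is a private nn.Module method whose signature has changed across releases (newer versions add a `recurse` argument), so the override forwards any extra arguments. `.to()`, `.cuda()`, `.cpu()` and `.double()` all go through it, so one override covers them. WithScratch and `scratch` are hypothetical names.

```python
import torch
import torch.nn as nn


class WithScratch(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 3)
        # Plain tensor attribute: not a parameter or buffer, so not in state_dict().
        self.scratch = torch.zeros(3, 3)

    def _apply(self, fn, *args, **kwargs):
        # Let nn.Module convert parameters, buffers and submodules first ...
        super()._apply(fn, *args, **kwargs)
        # ... then apply the same conversion (device and/or dtype) to the extra tensor.
        self.scratch = fn(self.scratch)
        return self
```

Since this relies on a private method, a non-persistent buffer is usually the less fragile choice when your PyTorch version supports it.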

Hi, one more question:
I have a huge tensor (700MB, precomputed, requires_grad=False) that is used in a matrix multiplication inside a Module (as shown in the snippet).

When training the model on multiple GPUs, I need to push it to all GPUs. The easiest way would be to use register_buffer in the module. However, this means the state_dict would be larger than 700MB (definitely not a good idea). So what is the best way to push such a large tensor to all GPUs?

BTW, if I simply use tensor.to(device), is the tensor pushed to all GPUs or only to the default one? (I ran a test, and it seems to end up on the default GPU only, not on all GPUs.)

Thanks in advance!

```python
import torch
import torch.nn as nn


class NewModule(nn.Module):

    def __init__(self, pre_matrix):
        super(NewModule, self).__init__()
        # pre_matrix: N x P, about 700MB, requires_grad=False
        self.pre_matrix = pre_matrix
        self.pre_matrix.requires_grad = False
        # self.register_buffer('pre_matrix', pre_matrix)  # this makes the state_dict larger than 700MB

    def forward(self, input):
        # input: M x N, on multiple GPUs
        # output: M x P, on multiple GPUs
        out = input @ self.pre_matrix
        return out
```

You could still use register_buffer and set persistent=False, which won’t add this buffer to the state_dict, as described in the docs.
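Applied to the snippet from the question, a minimal sketch could look like this. The `persistent` keyword exists in recent PyTorch releases; older versions do not accept it. Because the tensor is still a registered buffer, it moves with the module on .to()/.cuda(), and nn.DataParallel copies buffers to each replica, while state_dict() leaves it out.

```python
import torch
import torch.nn as nn


class NewModule(nn.Module):
    def __init__(self, pre_matrix):
        super().__init__()
        # Non-persistent buffer: moved together with the module,
        # but excluded from state_dict(), so checkpoints stay small.
        self.register_buffer("pre_matrix", pre_matrix, persistent=False)

    def forward(self, x):
        # x: (M, N), pre_matrix: (N, P) -> output: (M, P)
        return x @ self.pre_matrix
```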
