Track non-learnable parameters in PyTorch

Hello.

I have a model in which some parameters are not learnable (e.g. a PCA projection). I would like to know if there is a way to track non-learnable parameters in a PyTorch nn.Module so that when I execute my_model.to(device), all of them are moved between devices.

Why do I need this behavior? My models are trained on the GPU but evaluated on the CPU to be safe, so that I don't run out of memory. This is because during evaluation I am computing a Monte Carlo expectation, so the predictive distributions use more samples at test time than during training.

Thank you

My solution has been to add this method:

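```python
def _custom_to(self, device):
    self.to(device)
    for mod in self.modules():
        # __move_to_device__ is my own hook, not part of nn.Module;
        # skip modules that don't define it instead of swallowing every error.
        hook = getattr(mod, "__move_to_device__", None)
        if hook is not None:
            hook(device)
```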

Each module can now define a __move_to_device__ method that moves the desired variables (in my case, the non-learnable parameter) to the given device. I am not sure whether there is a better way to do this, or whether it would be worth incorporating this feature into the library. It would be as easy as adding the _custom_to() method and a default __move_to_device__ that just does pass, which users can override as I have done. The other option, which I think is better, is to call __move_to_device__ directly inside the standard nn.Module.to() method.

Hi,

This is what buffers are for.
You can do self.register_buffer("foo", your_tensor) and then use it as if you had done self.foo = your_tensor.
It will be moved around for you and saved into the state dict as well.
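A minimal sketch of the buffer approach applied to the PCA case from the question. The class name PCAProjection and the random 8 x 3 matrix are hypothetical placeholders for the real projection. The cast to torch.float64 is only there to show that .to() also touches the buffer when no GPU is available.

```python
import torch
import torch.nn as nn


class PCAProjection(nn.Module):
    """Projects inputs onto a fixed, non-learnable basis (e.g. PCA components)."""

    def __init__(self, components: torch.Tensor):
        super().__init__()
        # Registered as a buffer: tracked by the module, moved/cast by .to(),
        # saved in state_dict(), but never returned by .parameters().
        self.register_buffer("W", components)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # No manual .to(...) here: self.W is wherever the module is.
        return x @ self.W


# Hypothetical stand-in for precomputed PCA components (8 features -> 3).
model = PCAProjection(torch.randn(8, 3))
print(list(model.parameters()))         # []
print(list(model.state_dict().keys()))  # ['W']

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
print(model.W.device.type == device)    # True

# Dtype casts go through the same mechanism, so the buffer follows those too.
model = model.to(torch.float64)
x = torch.randn(2, 8, dtype=torch.float64, device=device)
print(model(x).shape)                   # torch.Size([2, 3])
```

Switching between GPU training and CPU evaluation then reduces to model.to("cuda") / model.to("cpu"), with no per-forward device handling.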

Thanks.

But if it is saved into the state dict, doesn't that mean I have to remove it when passing model.parameters() to an optimizer? As far as I remember, the optimizer raises an error for tensors that do not require grad.

You pass mod.parameters() to the optimizer, right? The state dict, on the other hand, is mod.state_dict(), which is what you save. And what you save contains the buffers.

Typical buffers are, for example, the running statistics of batchnorm.
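A small check of the distinction using nn.BatchNorm1d (the feature size 4 and learning rate are arbitrary). The learnable affine weights show up in .parameters(), the running statistics show up only as buffers, and state_dict() contains both, so an optimizer built from .parameters() never sees the buffers.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)

# Learnable tensors: what the optimizer should update.
print([name for name, _ in bn.named_parameters()])
# ['weight', 'bias']

# Buffers: running statistics, updated inside forward() during training.
print([name for name, _ in bn.named_buffers()])
# ['running_mean', 'running_var', 'num_batches_tracked']

# state_dict() holds both kinds of tensors.
print(list(bn.state_dict().keys()))
# ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']

# The optimizer only receives .parameters(), so no buffer needs to be filtered out.
opt = torch.optim.SGD(bn.parameters(), lr=0.1)
print(sum(len(group["params"]) for group in opt.param_groups))  # 2
```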

Yes, but in my case I don't really need to save this parameter into my state_dict. I just need it to be moved between devices when I call model.to(device), and not to be passed to the optimizer when I call model.parameters().

I could work around this issue by running something like this in the forward of my nn.Module:

```python
W = self.W.to(X.device)  # X can be on the GPU or the CPU depending on the iteration
Y = X @ W
```

But I would prefer not to call .to() on every forward, and instead handle the device of all these tensors through a single call to my_model.to().

Again, buffers never appear in model.parameters()! Only parameters do.

But if you don't want a buffer to be part of the state dict (which is different from .parameters()!), you can pass persistent=False to register_buffer() when you create it :smiley:
You can check the doc for more details.
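A minimal sketch of the non-persistent variant, assuming a PyTorch version that supports the persistent argument of register_buffer (very old releases lack it). The Projector name and random matrix are hypothetical placeholders. The buffer is still listed by named_buffers() and still follows .to(), but it is left out of state_dict() and never appears in .parameters().

```python
import torch
import torch.nn as nn


class Projector(nn.Module):
    def __init__(self, components: torch.Tensor):
        super().__init__()
        # persistent=False: still a buffer (moved/cast by .to()),
        # but excluded from state_dict().
        self.register_buffer("W", components, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.W


m = Projector(torch.randn(8, 3))
print(list(m.parameters()))               # []
print(list(m.state_dict().keys()))        # []
print([n for n, _ in m.named_buffers()])  # ['W']

m = m.to(torch.float64)
print(m.W.dtype)                          # torch.float64
```

One consequence of this choice: since W is not in the state dict, load_state_dict() will not restore it, so the module has to rebuild it (for example, recompute or reload the PCA) in __init__.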

Thank you. Sorry, I was misunderstanding. Now it is clear.