Set weights of Conv layer and make them non-trainable

I’m writing a module that includes some Conv2D layers and I want to manually set their weights and make them non-trainable. My module is something like this:

import torch
import torch.nn as nn


def SetWeights():
    ## manual function to set the weights
    ## returns a 4D tensor
    return

class Module(nn.Module):
    def __init__(self):
        super().__init__()

        self.Conv1 = nn.Conv2d(128,128,kernel_size=(2,2),padding='same')
        self.Conv2 = nn.Conv2d(128,128,kernel_size=(3,3),padding='same')
        self.Conv3 = nn.Conv2d(128,128,kernel_size=(5,5),padding='same')
        self.Conv4 = nn.Conv2d(128,128,kernel_size=(7,7),padding='same')


    def forward(self, x):

        ## usual forward pass
        return x

In Keras it’s something like:

     Conv1.set_weights([weights1])
     Conv2.set_weights([weights2])
     Conv3.set_weights([weights3])
     Conv4.set_weights([weights4])
     Conv1.trainable = False
     Conv2.trainable = False 
     Conv3.trainable = False
     Conv4.trainable = False

How can I do that in PyTorch?

P.S. The weights are NumPy arrays.


You can either assign the new weights via:

with torch.no_grad():
    self.Conv1.weight = nn.Parameter(...)
    # or
    self.Conv1.weight.copy_(tensor)

and set their .requires_grad attribute to False to freeze them. Alternatively, you could also directly use the functional API:

x = F.conv2d(input, self.weight)
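
Putting both steps together, here is a minimal sketch based on the module from the question. The constructor arguments weights1/weights2 and their shapes are just illustrative (the conv weight shape for Conv1 would be (128, 128, 2, 2)), and only two of the four conv layers are shown:

import numpy as np
import torch
import torch.nn as nn


class Module(nn.Module):
    def __init__(self, weights1, weights2):
        super().__init__()
        self.Conv1 = nn.Conv2d(128, 128, kernel_size=(2, 2), padding='same')
        self.Conv2 = nn.Conv2d(128, 128, kernel_size=(3, 3), padding='same')
        # ... Conv3 / Conv4 analogously

        with torch.no_grad():
            # copy the NumPy weights into the existing parameters
            self.Conv1.weight.copy_(torch.from_numpy(weights1))
            self.Conv2.weight.copy_(torch.from_numpy(weights2))

        # freeze weight and bias so the optimizer won't update them
        for layer in (self.Conv1, self.Conv2):
            for param in layer.parameters():
                param.requires_grad = False

    def forward(self, x):
        # the forward pass is unchanged; frozen layers are applied as usual
        x = self.Conv1(x)
        x = self.Conv2(x)
        return x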

OK, so if I use torch.no_grad(), I need to use it inside __init__(). And what about forward(): do we need to make any changes in that function?

If you freeze the parameters, you won't have to change the forward pass and can just apply the conv layers as usual.
You could also use no_grad() in the forward, but since the parameters should not be trained at all, the right approach is to just freeze them.
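
As a quick sanity check (a sketch reusing the Module class from above, with illustrative random NumPy weights), the frozen parameters receive no gradient after backward():

w1 = np.random.randn(128, 128, 2, 2).astype(np.float32)
w2 = np.random.randn(128, 128, 3, 3).astype(np.float32)
model = Module(w1, w2)

# the input requires grad only so that backward() has something to compute
x = torch.randn(1, 128, 16, 16, requires_grad=True)
model(x).sum().backward()

print(model.Conv1.weight.requires_grad)  # False
print(model.Conv1.weight.grad)           # None: no gradient was accumulated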

Ok, thanks, I got my answer.

In case I want to make each and every layer non-trainable, is there any shorthand for it, or do I have to write the above line for each and every layer?

To freeze the entire model, you would iterate over the parameters and set their .requires_grad attribute to False.
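
For example, a short sketch (assuming model is your instantiated module):

for param in model.parameters():
    param.requires_grad = False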
