I’m writing a module that includes some Conv2D layers and I want to manually set their weights and make them non-trainable. My module is something like this:
import torch
import torch.nn as nn

def SetWeights():
    ## manual function to set weights;
    ## returns a 4D tensor
    return

class Module(nn.Module):
    def __init__(self):
        super().__init__()
        self.Conv1 = nn.Conv2d(128, 128, kernel_size=(2, 2), padding='same')
        self.Conv2 = nn.Conv2d(128, 128, kernel_size=(3, 3), padding='same')
        self.Conv3 = nn.Conv2d(128, 128, kernel_size=(5, 5), padding='same')
        self.Conv4 = nn.Conv2d(128, 128, kernel_size=(7, 7), padding='same')

    def forward(self, x):
        ## usual forward pass
        return x
In Keras, the equivalent would be something like:
Conv1.set_weights([weights1])
Conv2.set_weights([weights2])
Conv3.set_weights([weights3])
Conv4.set_weights([weights4])
Conv1.trainable = False
Conv2.trainable = False
Conv3.trainable = False
Conv4.trainable = False
How can I do that in PyTorch?
P.S. The weights are NumPy arrays.
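In other words, I'm looking for something along the lines of the following sketch for a single layer (assuming the array already has the Conv2d weight shape `(out_channels, in_channels, kH, kW)`; `copy_`, `torch.from_numpy`, and `requires_grad` are standard PyTorch APIs):

```python
import numpy as np
import torch
import torch.nn as nn

conv = nn.Conv2d(128, 128, kernel_size=(3, 3), padding='same')

# a 4D NumPy array with the Conv2d weight shape: (out_channels, in_channels, kH, kW)
weights = np.random.rand(128, 128, 3, 3).astype(np.float32)

# copy the NumPy values into the layer's weight tensor in place
with torch.no_grad():
    conv.weight.copy_(torch.from_numpy(weights))

# exclude the weight from gradient updates, i.e. make it non-trainable
conv.weight.requires_grad = False
```

but I'm not sure whether this is the idiomatic way to do it for all four layers of the module.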