I’m working on an experiment where I train a model with a CNN backbone that is applied both to the input image and to a set of learnable prototypes. During training I want to update the weights of the CNN when it is applied to the image. When it is applied to the prototypes, I want to freeze the CNN weights but still update the prototypes.
I tried wrapping the CNN call on the prototypes in `torch.no_grad()`, but this resulted in the prototypes not being updated either.
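Concretely, this is roughly what the failed attempt looked like (a minimal sketch of the `no_grad` version of `forward`):

```python
def forward(self, x):
    x = self.backbone(x)
    with torch.no_grad():
        # no_grad disables autograd for everything inside the block,
        # so no gradient reaches the CNN weights *or* the prototypes
        y = self.backbone(self.prototype)
    return x, y
```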
Does anyone know how to implement this?
Here is part of the code used in the experiment:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.conv3 = nn.Conv2d(64, 64, 3, 1)
        self.conv4 = nn.Conv2d(64, 128, 3, 1)

    def forward(self, x):
        x = F.relu(self.conv2(F.relu(self.conv1(x))))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv4(F.relu(self.conv3(x))))
        x = F.max_pool2d(x, 2, 2)
        return x


class Model(nn.Module):
    def __init__(self, n_prototypes):
        super(Model, self).__init__()
        self.backbone = CNN()
        # one single-channel 28x28 image per prototype; the channel
        # dimension is needed so Conv2d(1, 32, ...) accepts the input
        self.prototype = nn.Parameter(
            torch.rand(n_prototypes, 1, 28, 28)
        )

    def forward(self, x):
        x = self.backbone(x)               # update CNN kernels
        y = self.backbone(self.prototype)  # update prototypes but not CNN kernels
        return x, y
```
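For context, a training step looks roughly like this (a sketch; the batch and the loss term are placeholders, not my exact code):

```python
model = Model(n_prototypes=10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

images = torch.rand(8, 1, 28, 28)  # dummy MNIST-sized batch
x_feat, proto_feat = model(images)

# placeholder loss comparing image features to prototype features
loss = (x_feat.mean() - proto_feat.mean()).pow(2)

optimizer.zero_grad()
loss.backward()  # this is where the CNN weights should only receive
                 # gradients from x_feat, not from proto_feat
optimizer.step()
```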
Thanks in advance.