I’m working on an experiment where I train a model with a CNN backbone that is applied both to the input image and to a set of learnable prototypes. During training I want to update the weights of the CNN when it is applied to the image. When it is applied to the prototypes, I want to freeze the CNN weights but still update the prototypes.
I tried wrapping the CNN call on the prototypes in `torch.no_grad()`, but this resulted in the prototypes not being updated either.
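Concretely, this is roughly what the failed attempt looked like (a minimal sketch of the `no_grad` version of `forward`):

```python
def forward(self, x):
    x = self.backbone(x)
    with torch.no_grad():
        # no_grad disables autograd for everything inside the block,
        # so no gradient reaches the CNN weights *or* the prototypes
        y = self.backbone(self.prototype)
    return x, y
```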
Does anyone know how to implement this?
Here is part of the code used in the experiment:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.conv3 = nn.Conv2d(64, 64, 3, 1)
        self.conv4 = nn.Conv2d(64, 128, 3, 1)

    def forward(self, x):
        x = F.relu(self.conv2(F.relu(self.conv1(x))))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv4(F.relu(self.conv3(x))))
        x = F.max_pool2d(x, 2, 2)
        return x


class Model(nn.Module):
    def __init__(self, n_prototypes):
        super(Model, self).__init__()
        self.backbone = CNN()
        # one single-channel 28x28 image per prototype; the channel
        # dimension is needed so Conv2d(1, 32, ...) accepts the input
        self.prototype = nn.Parameter(
            torch.rand(n_prototypes, 1, 28, 28)
        )

    def forward(self, x):
        x = self.backbone(x)               # update CNN kernels
        y = self.backbone(self.prototype)  # update prototypes but not CNN kernels
        return x, y
```
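For context, a training step looks roughly like this (a sketch; the batch and the loss term are placeholders, not my exact code):

```python
model = Model(n_prototypes=10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

images = torch.rand(8, 1, 28, 28)  # dummy MNIST-sized batch
x_feat, proto_feat = model(images)

# placeholder loss comparing image features to prototype features
loss = (x_feat.mean() - proto_feat.mean()).pow(2)

optimizer.zero_grad()
loss.backward()  # this is where the CNN weights should only receive
                 # gradients from x_feat, not from proto_feat
optimizer.step()
```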
Thanks in advance.