Conditional weight updates


I’m working on an experiment where I train a model with a CNN backbone that is applied both to the input image and to a set of learnable prototypes. During training, I want to update the CNN weights when it is applied to the image. When it is applied to the prototypes, I want to freeze the CNN weights but still update the prototypes.

I tried wrapping the CNN call on the prototypes in torch.no_grad, but then the prototypes were not updated either.
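A minimal sketch of why this happens, using a single Conv2d as a stand-in for the CNN (an assumption for brevity): inside torch.no_grad no graph is recorded, so as far as autograd is concerned the output does not depend on the prototypes at all.

```python
import torch
import torch.nn as nn

# Stand-in for the CNN backbone (assumption: 1-channel 28x28 inputs).
backbone = nn.Conv2d(1, 4, kernel_size=3)
prototype = nn.Parameter(torch.rand(2, 1, 28, 28))

with torch.no_grad():
    y = backbone(prototype)  # no graph is recorded inside no_grad

print(y.requires_grad)  # False: y is cut off from `prototype`
# y.sum().backward() would raise a RuntimeError here because y has no grad_fn,
# so neither the kernels nor the prototypes can receive gradients through y.
```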

Does anyone know how to set this up?

Here is a part of the code used in the experiment.

import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()

        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.conv3 = nn.Conv2d(64, 64, 3, 1)
        self.conv4 = nn.Conv2d(64, 128, 3, 1)

    def forward(self, x):
        x = F.relu(self.conv2(F.relu(self.conv1(x))))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv4(F.relu(self.conv3(x))))
        x = F.max_pool2d(x, 2, 2)
        return x

class Model(nn.Module):
    def __init__(self, n_prototypes):
        super(Model, self).__init__()

        self.backbone = CNN()

        self.prototype = nn.Parameter(
            torch.rand((n_prototypes, 1) + (28, 28)))  # (N, C, H, W): conv1 expects 1 channel

    def forward(self, x):

        x = self.backbone(x)  # update CNN kernels
        y = self.backbone(self.prototype)  # update prototypes but not CNN kernels

        return x, y
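One possible fix, not proposed in this thread and sketched here assuming PyTorch 2.0 or later: call the backbone on the prototypes through torch.func.functional_call with detached copies of its parameters. The prototypes still require grad, so they get updated, while the kernels act as constants in that branch. The sketch also gives the prototypes the channel dimension that conv1 expects.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call  # PyTorch 2.0+ (assumption about your version)

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.conv3 = nn.Conv2d(64, 64, 3, 1)
        self.conv4 = nn.Conv2d(64, 128, 3, 1)

    def forward(self, x):
        x = F.relu(self.conv2(F.relu(self.conv1(x))))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv4(F.relu(self.conv3(x))))
        return F.max_pool2d(x, 2, 2)

class Model(nn.Module):
    def __init__(self, n_prototypes):
        super().__init__()
        self.backbone = CNN()
        # Prototypes shaped like images: (N, 1, 28, 28)
        self.prototype = nn.Parameter(torch.rand(n_prototypes, 1, 28, 28))

    def forward(self, x):
        x = self.backbone(x)  # gradients reach the CNN kernels
        # Same backbone, but with detached copies of its parameters:
        # autograd treats them as constants, so gradients from this branch
        # reach self.prototype but not the kernels.
        frozen = {name: p.detach() for name, p in self.backbone.named_parameters()}
        y = functional_call(self.backbone, frozen, (self.prototype,))
        return x, y

torch.manual_seed(0)
model = Model(n_prototypes=3)
x, y = model(torch.rand(4, 1, 28, 28))
y.sum().backward()  # a loss that only uses the prototype branch
print(model.prototype.grad is not None)                          # True
print(all(p.grad is None for p in model.backbone.parameters()))  # True
```

Because the freezing happens inside forward, a single optimizer over model.parameters() is enough with this approach.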

Thanks in advance.

You could consider using separate optimizers.
This is a common way to train generative adversarial networks (train the generator network for one step, then the discriminator network for one step, and so on).
For example:

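model_1 = CNN()
model_2 = Prototype()
optimizer_1 = torch.optim.Adam(model_1.parameters())
optimizer_2 = torch.optim.Adam(model_2.parameters())

# during training

if prototype:
    ...  # step optimizer_2 (prototypes only)
else:
    ...  # step optimizer_1 (CNN only)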

Is this what you are looking for?
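A runnable version of the alternating idea, with a single Conv2d standing in for CNN() and a made-up loss (both assumptions). Note that loss.backward() still fills in kernel gradients from the prototype branch; the alternation only controls which parameters each step moves.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
backbone = nn.Conv2d(1, 4, kernel_size=3)           # stand-in for CNN() (assumption)
prototype = nn.Parameter(torch.rand(2, 1, 28, 28))  # learnable prototypes
opt_cnn = torch.optim.Adam(backbone.parameters(), lr=1e-3)
opt_proto = torch.optim.Adam([prototype], lr=1e-3)

images = torch.rand(8, 1, 28, 28)
w0 = backbone.weight.detach().clone()
p0 = prototype.detach().clone()

for step in range(4):
    x = backbone(images)
    y = backbone(prototype)
    loss = x.mean() + y.mean()  # hypothetical loss that uses both branches
    opt_cnn.zero_grad()
    opt_proto.zero_grad()
    loss.backward()
    if step % 2 == 0:
        opt_proto.step()  # prototype step: only the prototypes move
    else:
        opt_cnn.step()    # CNN step: only the kernels move
```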

Besides, torch.no_grad disables graph recording for everything computed inside it, so its output has no connection to the prototypes and loss.backward() cannot send any gradient back to them.
You can also inspect the gradients after loss.backward() is called:

print(model.backbone.conv1.weight.grad)  # print(model.some_layer.weight.grad)
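A short sketch for checking which parameters received gradients. Conv2d modules have no .grad attribute of their own; the gradients live on their parameters (e.g. conv1.weight.grad). The small Sequential backbone below is a stand-in, not the model from the question.

```python
import torch
import torch.nn as nn

# Stand-in backbone (assumption); named_parameters() works the same on any nn.Module.
backbone = nn.Sequential(nn.Conv2d(1, 4, 3), nn.ReLU(), nn.Conv2d(4, 8, 3))
images = torch.rand(8, 1, 28, 28)
loss = backbone(images).mean()
loss.backward()

grad_norms = {}
for name, p in backbone.named_parameters():
    # p.grad stays None when no gradient reached this parameter
    grad_norms[name] = None if p.grad is None else p.grad.norm().item()
    print(name, grad_norms[name])
```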

That goes in the right direction, but x and y are also used in another module, so I would probably need an optimizer for that part too, which makes it more complex.
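On the extra optimizer: a single torch.optim optimizer can hold parameters from several modules, and if the other module is a submodule of Model, model.parameters() already includes it. A sketch with hypothetical stand-in modules:

```python
import itertools
import torch
import torch.nn as nn

model = nn.Conv2d(1, 4, 3)  # stand-in for the backbone + prototypes model (assumption)
head = nn.Linear(4, 10)     # hypothetical "other module" that consumes x and y

# Option 1: one optimizer over the parameters of both modules
opt = torch.optim.Adam(itertools.chain(model.parameters(), head.parameters()), lr=1e-3)

# Option 2: parameter groups, e.g. a different learning rate per module
opt_groups = torch.optim.Adam([
    {"params": model.parameters()},
    {"params": head.parameters(), "lr": 1e-4},  # overrides the default lr below
], lr=1e-3)
```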

Is there a way to set the kernel gradients coming from the prototype branch to zero before they are applied?
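One way to get this effect without a second optimizer (a sketch, not taken from the thread): switch requires_grad off on the backbone parameters while the prototype branch is computed, then switch it back on. Autograd decides which tensors to track at the moment each operation runs, so the kernels only receive gradients from the image branch while the prototypes still receive gradients from theirs. A single Conv2d stands in for the CNN (assumption).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
backbone = nn.Conv2d(1, 4, kernel_size=3)           # stand-in for CNN() (assumption)
prototype = nn.Parameter(torch.rand(2, 1, 28, 28))  # learnable prototypes
images = torch.rand(8, 1, 28, 28)

def forward(images):
    x = backbone(images)  # recorded while the kernels require grad
    # Treat the kernels as constants while the prototype branch is recorded.
    for p in backbone.parameters():
        p.requires_grad_(False)
    try:
        y = backbone(prototype)  # the graph only tracks `prototype` here
    finally:
        for p in backbone.parameters():
            p.requires_grad_(True)
    return x, y

x, y = forward(images)
(x.mean() + y.mean()).backward()  # kernel grads come from the image branch only
```

The try/finally re-enables the kernels even if the prototype forward raises, so later iterations are not left with a frozen backbone.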