Backpropagating through only part of a network at first

Hello everyone, I am struggling with the backward() function for this architecture I’m working with:

```python
import torch

from losses import mmd_NTK
from models import FullyConnected, Convolutional

latent_size = 16
net = FullyConnected(latent_size=latent_size)

# noise_batch_size and train_loader are defined in the code omitted here
noise = torch.randn((noise_batch_size, latent_size))
input_images, input_labels = next(iter(train_loader))
generated_images = net(noise)

classifier = Convolutional()  # already trained; its weights should stay fixed

loss = mmd_NTK(input_images, generated_images, classifier)
```

I’ve skipped many steps of the actual code, but I believe it’s clear what I’m doing. I would like to optimize the parameters of the network net by backpropagating loss. But I run into problems before I even get there. In particular, my mmd_NTK function is defined as follows:

```python
def mmd_NTK(batch_images, batch_generated_images, classifier):
    vec1 = get_NTK_avg_feature(batch_images, classifier)
    vec2 = get_NTK_avg_feature(batch_generated_images, classifier)
    return distance(vec1, vec2)  # distance() is defined elsewhere
```
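The post does not show `distance`. A minimal sketch of what such a helper could look like, assuming each vector is a list of tensors (one per classifier parameter) and that a squared Euclidean distance is wanted; the implementation is hypothetical:

```python
import torch


def distance(vec1, vec2):
    """Hypothetical stand-in for the undefined `distance` helper.

    Assumes vec1 and vec2 are lists of tensors with matching shapes and
    returns the squared Euclidean distance between their flattened
    concatenations. Only differentiable tensor ops are used, so gradients
    can flow back through it.
    """
    flat1 = torch.cat([v.reshape(-1) for v in vec1])
    flat2 = torch.cat([v.reshape(-1) for v in vec2])
    return ((flat1 - flat2) ** 2).sum()


a = [torch.ones(2, 2), torch.zeros(3)]
b = [torch.zeros(2, 2), torch.ones(3)]
print(distance(a, b))  # tensor(7.)
```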

My loss is the distance between these two vectors, which are the derivatives of a scalar function of the outputs with respect to the parameters of another network, classifier. That network is already trained, and I don’t want to train it any further.
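Written out in symbols (the notation here is introduced for illustration and is not from the post): let $f_\theta$ be classifier with parameters $\theta$, $G_\phi$ be net with parameters $\phi$, $X$ a batch of real images, $Z$ a noise batch and $d$ the distance. The feature computed by get_NTK_avg_feature and the loss are then

$$
\Phi_\theta(X) = \nabla_\theta \left\lVert \frac{1}{|X|} \sum_{x \in X} f_\theta(x) \right\rVert_2^2,
\qquad
\mathcal{L}(\phi) = d\big(\Phi_\theta(X),\ \Phi_\theta(G_\phi(Z))\big)
$$

Because $\mathcal{L}$ depends on $\phi$ only through a gradient with respect to $\theta$, computing $\partial \mathcal{L} / \partial \phi$ with $\theta$ held fixed means differentiating through that gradient, i.e. a second-order (double-backward) computation.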

```python
def get_NTK_avg_feature(images, classifier):
    outputs = classifier(images)
    avg_outputs = outputs.mean(dim=0)
    squared_outputs = (avg_outputs ** 2).sum()  # reduce to a scalar so it can be differentiated
    parameters_grad = []
    for param in classifier.parameters():
        ...  # body omitted in the post: fills parameters_grad with the derivatives of squared_outputs
    return parameters_grad
```
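Since the loop body is not shown, here is one possible way to build the same list of derivatives; it is a sketch, not the author's code. It uses `torch.autograd.grad` with `create_graph=True`, which returns new gradient tensors (instead of writing into `param.grad`) that stay attached to the graph, so whatever produced `images` can still receive gradients later. The small `nn.Linear` networks are toy stand-ins, assumed only to make the example runnable:

```python
import torch
import torch.nn as nn


def get_NTK_avg_feature(images, classifier):
    """Sketch: gradient of ||mean output||^2 w.r.t. the classifier parameters."""
    outputs = classifier(images)
    avg_outputs = outputs.mean(dim=0)
    squared_outputs = (avg_outputs ** 2).sum()  # scalar
    params = [p for p in classifier.parameters() if p.requires_grad]
    # create_graph=True keeps the returned gradients differentiable, so a loss
    # built from them can later be backpropagated into the producer of `images`.
    return list(torch.autograd.grad(squared_outputs, params, create_graph=True))


# Toy stand-ins (assumptions, not the networks from the post).
torch.manual_seed(0)
classifier = nn.Linear(4, 3)
generator = nn.Linear(2, 4)
features = get_NTK_avg_feature(generator(torch.randn(5, 2)), classifier)
print([f.shape for f in features])              # [torch.Size([3, 4]), torch.Size([3])]
print(all(f.requires_grad for f in features))   # True
```

Note that the classifier parameters must keep `requires_grad=True` here, otherwise there is nothing to differentiate with respect to.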

Unfortunately, when I run this a second time, PyTorch complains that I’m trying to backward through the graph a second time. The problem arises because when I call backward inside get_NTK_avg_feature, the first network, net, is also taken into account. I don’t want this. But I also don’t want to completely remove the dependency of my loss (the distance between vec1 and vec2) on net, since I want to optimize the loss with respect to net’s parameters. Any clue?
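One possible approach, sketched under assumptions and not taken from the post: compute the features with `torch.autograd.grad(..., create_graph=True)` instead of calling `backward()` inside get_NTK_avg_feature, call `loss.backward()` only once per freshly built graph, and keep the classifier fixed simply by not handing its parameters to the optimizer. The `nn.Sequential` networks, sizes and random "real" batch below are hypothetical stand-ins for FullyConnected, Convolutional and train_loader, and the squared Euclidean distance is an assumed choice:

```python
import torch
import torch.nn as nn


def get_NTK_avg_feature(images, classifier):
    outputs = classifier(images)
    squared_outputs = (outputs.mean(dim=0) ** 2).sum()
    # create_graph=True: the returned gradients stay differentiable.
    return torch.autograd.grad(squared_outputs, list(classifier.parameters()),
                               create_graph=True)


def mmd_NTK(batch_images, batch_generated_images, classifier):
    vec1 = get_NTK_avg_feature(batch_images, classifier)
    vec2 = get_NTK_avg_feature(batch_generated_images, classifier)
    # Squared Euclidean distance between the two gradient "vectors" (an assumption).
    return sum(((g1 - g2) ** 2).sum() for g1, g2 in zip(vec1, vec2))


torch.manual_seed(0)
latent_size, image_size = 16, 8
net = nn.Sequential(nn.Linear(latent_size, 32), nn.ReLU(), nn.Linear(32, image_size))
classifier = nn.Sequential(nn.Linear(image_size, 10), nn.Tanh(), nn.Linear(10, 3))

# Only net's parameters go to the optimizer, so the classifier is never updated.
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
classifier_before = [p.detach().clone() for p in classifier.parameters()]
real_images = torch.randn(64, image_size)  # stand-in for a batch from train_loader

for step in range(3):
    optimizer.zero_grad()
    generated_images = net(torch.randn(64, latent_size))  # fresh graph every step
    loss = mmd_NTK(real_images, generated_images, classifier)
    loss.backward()  # a single backward through this step's graph
    optimizer.step()

unchanged = all(torch.equal(a, b) for a, b in zip(classifier_before, classifier.parameters()))
print(unchanged)  # True
```

`loss.backward()` also fills the classifier's `.grad` fields, but since the optimizer never sees those parameters they are simply ignored.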

P.S. Also, it turns out that vec1 and vec2 are equal; more precisely, when I compute vec2, the values inside vec1 get automatically overwritten with those of vec2. This is a separate problem, though…
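A plausible cause, assuming (the loop body is not shown) that the list is filled with `param.grad` itself: the list then holds references to the very tensors PyTorch keeps updating in place, for example when gradients are zeroed with `zero_()` or accumulated by a later backward, so vec1 changes when vec2 is computed. The toy `nn.Linear` model below is an assumption used only to show the aliasing; storing `param.grad.clone()` (or using `torch.autograd.grad`, which returns new tensors) avoids it:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)

model(torch.randn(4, 3)).sum().backward()
aliased = [p.grad for p in model.parameters()]            # references to the .grad tensors
snapshot = [p.grad.clone() for p in model.parameters()]   # independent copies

# Any later in-place update of .grad (here an explicit zero_) also changes `aliased`.
for p in model.parameters():
    p.grad.zero_()

print(all(g.count_nonzero().item() == 0 for g in aliased))  # True
print(any(g.count_nonzero().item() > 0 for g in snapshot))  # True
```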