Hello everyone, I am struggling with the `backward()` call for this architecture I'm working with:

```
import torch

from losses import mmd_NTK
from models import FullyConnected, Convolutional

latent_size = 16
net = FullyConnected(latent_size=latent_size)

# noise_batch_size and train_loader are defined elsewhere in the real code
noise = torch.randn((noise_batch_size, latent_size))
input_images, input_labels = next(iter(train_loader))
generated_images = net(noise)

# Load the pretrained classifier; I don't want to train it any further
classifier = Convolutional()
classifier.load_state_dict(torch.load('./trained_classifier.pth'))

loss = mmd_NTK(input_images, generated_images, classifier)
```

I've skipped many steps of the actual code, but I think it's clear what I'm doing: I would like to optimize the parameters of the network `net` by backpropagating with respect to the loss. But I run into problems before that. In particular, my `mmd_NTK` function is defined as follows:

```
def mmd_NTK(batch_images, batch_generated_images, classifier):
    # Average NTK feature (gradients w.r.t. the classifier's parameters) for each batch
    vec1 = get_NTK_avg_feature(batch_images, classifier)
    vec2 = get_NTK_avg_feature(batch_generated_images, classifier)
    return distance(vec1, vec2)
```
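(By `distance` I mean, roughly, a summed squared difference over the two lists of per-parameter gradient tensors; a simplified sketch of what I have in mind:)

```
def distance(vec1, vec2):
    # vec1 and vec2 are lists of per-parameter gradient tensors;
    # sum the squared differences over all parameters to get a scalar.
    return sum(((g1 - g2) ** 2).sum() for g1, g2 in zip(vec1, vec2))
```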

My loss is the distance between these two vectors, which are the derivatives of some output function with respect to the parameters of another network, `classifier`, which is already trained and which I don't want to train any further. `get_NTK_avg_feature` looks like this:

```
def get_NTK_avg_feature(images, classifier):
    outputs = classifier(images)
    avg_outputs = outputs.mean(dim=0)
    squared_outputs = (avg_outputs ** 2).sum()  # This is to get a scalar as output
    squared_outputs.backward()
    parameters_grad = []
    for param in classifier.parameters():  # Now I fill a list with all the derivatives...
        parameters_grad.append(param.grad)
    return parameters_grad
```
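One idea I had is to replace the `backward()` call with `torch.autograd.grad`, which returns the gradients directly instead of accumulating them into `param.grad`, and which can keep the graph alive with `create_graph=True`. I'm not sure this is the right approach, though. A sketch of what I mean (the `_v2` name is just for illustration):

```
import torch

def get_NTK_avg_feature_v2(images, classifier):
    outputs = classifier(images)
    avg_outputs = outputs.mean(dim=0)
    squared_outputs = (avg_outputs ** 2).sum()
    # torch.autograd.grad returns fresh gradient tensors instead of writing
    # into param.grad, and create_graph=True keeps the graph so the result
    # stays differentiable with respect to net's parameters.
    return torch.autograd.grad(squared_outputs,
                               list(classifier.parameters()),
                               create_graph=True)
```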

Unfortunately, when I run this a second time PyTorch complains, telling me that I'm trying to backward through the graph a second time. The problem arises because when I call `backward()` inside `get_NTK_avg_feature`, the first network `net` is also taken into account (for the generated images, the graph reaches back through `net`). I don't want that. But I also don't want to completely erase the dependency of my loss (the distance between `vec1` and `vec2`) on `net`, since I will optimize the loss with respect to its parameters. Any clue?
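For context, the training step I'm aiming for looks roughly like this (the optimizer and `num_steps` here are just placeholders):

```
import torch

optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

for _ in range(num_steps):  # num_steps defined elsewhere
    optimizer.zero_grad()
    noise = torch.randn((noise_batch_size, latent_size))
    input_images, _ = next(iter(train_loader))
    generated_images = net(noise)
    loss = mmd_NTK(input_images, generated_images, classifier)
    loss.backward()   # should update only net's parameters
    optimizer.step()
```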

P.S. It also turns out that `vec1` and `vec2` end up equal, or rather, when I compute `vec2` the values inside `vec1` get automatically overwritten with the values of `vec2` (maybe because the list stores references to the same `param.grad` tensors?). That's an independent problem, though...