Hi all,
I am training a GAN for time series prediction. During training, I sample n_samples predictions and backpropagate the L2-loss for the best sample.

l2_losses = []
for _ in range(n_samples):
    out = net(x)
    loss = lossfunc(out, y)  # compare the prediction, not the input, with the target
    l2_losses.append(loss)
l2_losses = torch.stack(l2_losses)
min_loss = torch.min(l2_losses)
min_loss.backward()

When I call min_loss.backward(), the time for the backward pass grows linearly with n_samples. However, the gradients for the non-minimum losses should be 0 and should not affect the optimisation. I only want to backpropagate the L2-loss for the best sample.
Therefore, I do not understand why the calculation time increases instead of staying constant. Does .backward() calculate all gradients even if they are not used? Is there a way to make my code more efficient and detach all non-minimum losses?
Thank you for your answer:)

As far as I know, backward computes the gradients all the way back to the leaf nodes. I mean, backward has no way of knowing which particular gradients you are interested in, or when. Maybe the simplest way to compute only the gradients you are interested in is to do it semi-automatically via the grad function:

from torch.autograd import grad
d_a_b = grad(a, b, retain_graph=True)
d_a_c = grad(a, c, retain_graph=True)
...
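
To make the grad call above concrete, here is a minimal sketch on a toy expression. The tensors a, b and c and the formula for a are assumptions chosen purely for illustration, not taken from the original model. Note that grad returns a tuple with one entry per input, and that retain_graph=True is only needed when the same graph is differentiated again afterwards.

```python
import torch
from torch.autograd import grad

# Toy leaf tensors (illustrative assumptions, not the thread's model).
b = torch.tensor(2.0, requires_grad=True)
c = torch.tensor(3.0, requires_grad=True)

# a depends on both b and c: a = b * c + c^2
a = b * c + c ** 2

# grad returns a tuple, one gradient per requested input.
(d_a_b,) = grad(a, b, retain_graph=True)  # da/db = c
(d_a_c,) = grad(a, c)                     # da/dc = b + 2 * c

# Unlike backward(), grad does not accumulate into the .grad attributes.
b_grad_untouched = b.grad is None
```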

How about, instead of using torch.min to obtain min_loss, comparing the loss values inside the loop and keeping track of the smallest one? Then, at the end, call .backward() on min_loss:

min_loss_value = float("inf")  # any real loss will be smaller than this
for _ in range(n_samples):
    out = net(x)
    loss = lossfunc(out, y)
    if loss.detach() < min_loss_value:
        min_loss = loss
        min_loss_value = loss.detach()
min_loss.backward()
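
For anyone who wants to try the loop above in isolation, here is a self-contained sketch. The network, loss function, data shapes and the use of dropout as the source of randomness are all hypothetical stand-ins for the net, lossfunc, x and y from the question, so treat it as an illustration rather than the actual GAN setup.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the thread's net, lossfunc, x and y.
net = nn.Sequential(nn.Linear(4, 8), nn.Dropout(p=0.5), nn.Linear(8, 1))
lossfunc = nn.MSELoss()
x, y = torch.randn(16, 4), torch.randn(16, 1)
n_samples = 5

net.train()  # keep dropout active so every forward pass is a different sample

min_loss = None
for _ in range(n_samples):
    loss = lossfunc(net(x), y)
    # Compare plain Python floats; only the best loss tensor is kept,
    # so the graphs of the discarded samples are no longer referenced.
    if min_loss is None or loss.item() < min_loss.item():
        min_loss = loss

# Backward runs through the graph of a single sample only.
min_loss.backward()
has_grads = all(p.grad is not None for p in net.parameters())
```

Because only one loss tensor stays referenced at a time, the graphs of the other samples can be released as the loop proceeds, and the backward pass no longer depends on n_samples.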

I think this way you definitely have only one loss whose gradients are computed. With torch.min(), I guess the gradients for the other losses in the list may still be computed, even though they are zero. Basically, autograd does compute gradients through torch.max and torch.min. For example, if we have a tensor with elements [1, 2, 3], torch.max() returns the maximum value, 3, and the gradient with respect to the input is [0, 0, 1]. These gradient values describe the effect of each element on the output of torch.max().
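
The [1, 2, 3] example can be checked directly with a few lines. The variable names here are made up for the illustration.

```python
import torch

# Leaf tensor holding the three example values.
values = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# torch.max selects the largest element; autograd records the selection.
largest = torch.max(values)
largest.backward()

# Only the selected element receives a non-zero gradient.
grad_of_max = values.grad.tolist()
print(grad_of_max)  # [0.0, 0.0, 1.0]
```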