Multiple forward passes, single conditional backward pass

Suppose I run several forward passes through a single model inside a for loop, but call backward only once. Can PyTorch handle this? Do the intermediate values of one forward pass get overwritten by the next, making the backward pass incorrect? Or does PyTorch keep each forward pass in its own computation graph?

For example, the model might take a random quantity h as an extra input, and only the lowest resulting loss is optimized:


```python
import torch

size = 3
losses = torch.zeros(size)
for i, h in enumerate(torch.randn(size)):
    out = model(x, h)
    single_loss = get_loss(out, target)
    losses[i] = single_loss  # index by position i, not by the random value h
sort_losses, _ = torch.sort(losses)  # Sort by loss, low to high

loss = sort_losses[0]  # Optimize only the lowest loss
loss.backward()
```

Thanks in advance for any help.

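Yes, PyTorch handles this. Each call to model(x, h) builds its own computation graph, so the intermediate values of one forward pass are not overwritten by the next. Your approach computes the gradients of the model parameters with respect to the “lowest loss”, i.e. sort_losses[0]; the other losses receive a zero gradient through torch.sort, so they do not contribute to the update.
Note, however, that all computation graphs (with all their intermediate activations) are stored during the loop and kept until the backward pass frees them, which increases memory usage.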
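A minimal sketch to check this, assuming a hypothetical `NoisyModel` (a linear layer applied to `x + h`) and an MSE-based `get_loss`; shapes and names are illustrative, not from the original post. Approach 1 is the loop from the question. Approach 2 is one possible lower-memory variant: pick the best `h` under `torch.no_grad()` (no graphs recorded), then repeat only that forward pass with autograd enabled. It assumes the forward pass is deterministic (e.g. no dropout), otherwise the repeated pass would differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


class NoisyModel(nn.Module):
    """Hypothetical stand-in for model(x, h): a linear layer applied to x shifted by h."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 1)

    def forward(self, x, h):
        return self.fc(x + h)


def get_loss(out, target):
    """Hypothetical loss helper; MSE is an arbitrary choice."""
    return F.mse_loss(out, target)


model = NoisyModel()
x = torch.randn(8, 4)
target = torch.randn(8, 1)
hs = torch.randn(3)

# Approach 1 (as in the question): keep every graph, backprop through the minimum.
losses = torch.stack([get_loss(model(x, h), target) for h in hs])
sort_losses, _ = torch.sort(losses)  # low to high
sort_losses[0].backward()
grad_all_graphs = model.fc.weight.grad.clone()

# Approach 2 (assumed lower-memory variant): choose the best h without recording
# graphs, then repeat only that forward pass with autograd enabled.
model.zero_grad()
with torch.no_grad():
    best = torch.argmin(torch.stack([get_loss(model(x, h), target) for h in hs]))
get_loss(model(x, hs[best]), target).backward()
grad_one_graph = model.fc.weight.grad.clone()

print(torch.allclose(grad_all_graphs, grad_one_graph))  # True
```

Both approaches give the same gradient, which shows the passes do not interfere with each other. Approach 2 trades one extra forward pass for keeping only a single graph in memory.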

Also, if you are using e.g. batchnorm layers in training mode, their running stats will be updated in every forward pass (three times per iteration here, not once), which might not fit your use case.
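A small sketch of that effect, assuming a standalone `nn.BatchNorm1d` in place of a full model. `num_batches_tracked` counts the updates to the running stats. Switching to `eval()` stops the updates, but note that it also changes the normalization to use the running estimates instead of batch statistics, so it is a trade-off rather than a drop-in fix.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)  # a freshly created module is in training mode
x = torch.randn(8, 4)

# Three forward passes in training mode, as in the loop from the question.
for h in torch.randn(3):
    bn(x + h)  # each call updates running_mean / running_var

n_updates = bn.num_batches_tracked.item()
print(n_updates)  # 3: one update per forward pass, regardless of backward

# In eval mode the running stats are used for normalization and left unchanged.
bn.eval()
before = bn.running_mean.clone()
for h in torch.randn(3):
    bn(x + h)
unchanged = torch.equal(before, bn.running_mean)
print(unchanged)  # True
```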
