Suppose I run several forward passes through a single model inside a `for` loop, but call the backward pass only once. Can PyTorch handle this scenario? Do the intermediate values saved by one forward pass get overwritten by the next forward pass, making the backward pass incorrect? Or does PyTorch keep each forward pass in its own computational graph?
Here is an example of this scenario, where the model takes a random quantity `h` as an extra input and only the lowest of the resulting losses is optimized:
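```python
import torch

# model, x, target, get_loss and optimizer are defined elsewhere
optimizer.zero_grad()
size = 3
losses = torch.zeros(size)
for i, h in enumerate(torch.randn(size)):
    out = model(x, h)
    single_loss = get_loss(out, target)
    losses[i] = single_loss  # index by position, not by the random value h
sort_losses, _ = torch.sort(losses)  # Sort by loss, low to high
loss = sort_losses[0]  # Optimize only the lowest loss
loss.backward()
optimizer.step()
```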
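To check this empirically, here is a self-contained sketch of the same scenario. `NoisyModel`, its `nn.Linear(4, 1)` layer, the tensor shapes and the MSE loss used in `get_loss` are all hypothetical stand-ins for my real model and loss. The sketch compares the parameter gradients from "several forward passes, one backward on the lowest loss" against a fresh forward/backward pass that uses only the selected `h`. If PyTorch keeps each forward pass in its own graph, I would expect the two sets of gradients to match.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the real model: adds scalar noise h to the input.
class NoisyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x, h):
        return self.linear(x + h)  # h is a 0-d tensor and broadcasts over x

# Hypothetical stand-in for the real loss function.
def get_loss(out, target):
    return nn.functional.mse_loss(out, target)

model = NoisyModel()
x = torch.randn(8, 4)
target = torch.randn(8, 1)
hs = torch.randn(3)

# Several forward passes, then a single backward pass on the lowest loss.
model.zero_grad()
losses = torch.stack([get_loss(model(x, h), target) for h in hs])
sort_losses, order = torch.sort(losses)  # low to high
sort_losses[0].backward()
grads_loop = [p.grad.clone() for p in model.parameters()]

# Reference: one forward/backward pass using only the selected h.
model.zero_grad()
get_loss(model(x, hs[order[0]]), target).backward()
grads_ref = [p.grad.clone() for p in model.parameters()]

# Compare the gradients of each parameter (weight, then bias).
for g_loop, g_ref in zip(grads_loop, grads_ref):
    print(torch.allclose(g_loop, g_ref))
```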
Thanks in advance for any help.