Currently I have this code, which runs a batch through the model and stores the result for later use. Before the actual loss is computed, the gradient with respect to the input needs to be computed:
x.requires_grad = True          # track gradients with respect to the input batch
x.retain_grad()                 # keep .grad on x even if it is not a leaf tensor
tmp = x                         # keep a handle to the layer input
x = ml[i](x, start_p, mask)     # forward pass through the i-th layer
x.retain_grad()
x.backward(torch.ones_like(x), retain_graph=True)   # populates tmp.grad
intermediate_gradients.append(tmp.grad)              # store the input gradient
x.grad = None                                        # clear the output's gradient
However, running this backward pass also accumulates gradients into the .grad attribute of the network weights, which is not intended here. Is there a way to avoid that? Wrapping the computation in torch.no_grad() is not a solution, because once the actual loss is computed later, the weight gradients do need to be accumulated.
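To make the problem concrete, here is a minimal sketch of the accumulation I mean, using a plain nn.Linear as a stand-in for ml[i] (start_p and mask omitted):

import torch
import torch.nn as nn

layer = nn.Linear(8, 8)          # stand-in for ml[i]
x = torch.randn(4, 8, requires_grad=True)

out = layer(x)
out.backward(torch.ones_like(out), retain_graph=True)

print(x.grad is not None)             # True: the input gradient I want to keep
print(layer.weight.grad is not None)  # True: the weight gradient I want to avoid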