Hi! I am currently working on a project where, for a given trained model, I run inference on batches of inputs and compute (and store) the gradients of the output with respect to the inputs.
My code looks something like this:
import torch

gradients_list = []
for batch_of_inputs in batches:
    batch_of_inputs.requires_grad_()
    output = model(batch_of_inputs)
    # torch.autograd.grad returns a tuple, so unpack the single gradient tensor
    (gradients,) = torch.autograd.grad(
        outputs=output,
        inputs=batch_of_inputs,
        grad_outputs=torch.ones_like(output),
        retain_graph=False,
    )
    gradients_list.append(gradients.detach())
    # here I need to do something to delete all the gradients
The problem is that the memory usage keeps increasing with every batch until an OOM error is raised. I have tried adding
del gradients, batch_of_inputs, output
at the end of each iteration, but the problem persisted.
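For completeness, this is roughly how the end of the loop body looks with that cleanup in place (only the last lines of the iteration shown above):

    gradients_list.append(gradients.detach())
    # explicitly drop the references held by this iteration before the next batch
    del gradients, batch_of_inputs, output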
What would you suggest?
Thanks in advance