I have a custom layer whose forward and backward functions are decorated with @torch.inference_mode(). The layer receives preallocated tensors: it writes its output into one of them during the forward pass and the gradients into another during the backward pass. I suspect it is holding on to some intermediate tensors that should be freed as soon as the function returns, because memory keeps growing until I hit an OOM error, and swapping this layer for a different one makes the problem go away.
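A common first check that memory really grows from one step to the next is to count the tensors Python can see before and after a few iterations. The sketch below does this with only `gc.get_objects()` and `torch.is_tensor()`; the name `live_tensor_census` is made up for illustration. Grouping by shape, dtype and device usually makes it obvious which buffer is piling up. It is not exhaustive: tensors that never had a Python object (for example, intermediates created and saved entirely inside C++ operators) will not show up, so an empty result does not rule out a leak.

```python
import gc
from collections import Counter

import torch


def live_tensor_census():
    """Count tensors currently tracked by Python's garbage collector,
    keyed by (shape, dtype, device). Hypothetical helper for illustration."""
    counts = Counter()
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                counts[(tuple(obj.shape), obj.dtype, str(obj.device))] += 1
        except Exception:  # some tracked objects raise on inspection; skip them
            continue
    return counts
```

To use it, take `before = live_tensor_census()`, run a few steps through the layer, then look at `(live_tensor_census() - before).most_common(10)`. Counter subtraction keeps only keys whose count went up, so whatever remains is what accumulated across those steps.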
Can I somehow verify that this is actually the case? More generally, is there a way to list all the tensors that a computational graph is holding on to?
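One way to approach the second question is to walk the autograd graph backwards from an output's `grad_fn` through `next_functions` and inspect what each node holds. The sketch below assumes the layer is implemented as a custom `torch.autograd.Function`, and `saved_tensors_in_graph` is a hypothetical helper name. It looks in three places: the `_saved_<name>` attributes that built-in nodes expose (described in the "Saved tensors" part of PyTorch's autograd mechanics notes), `ctx.saved_tensors` for custom Functions, and the ctx's `__dict__` for tensors assigned directly as `ctx.something = tensor`. The last case is relevant to the symptom above: tensors stored on ctx are not released when backward() frees the saved buffers, only when the node itself is destroyed.

```python
import torch


def saved_tensors_in_graph(root):
    """Walk the autograd graph backwards from `root` (e.g. `loss.grad_fn`) and
    yield (node_name, attribute_name, tensor) for each tensor a node keeps alive."""
    seen = set()
    stack = [root]
    while stack:
        node = stack.pop()
        if node is None or node in seen:
            continue
        seen.add(node)
        node_name = type(node).__name__

        # Built-in (C++) nodes expose their saved values as `_saved_<name>` attributes.
        for attr in dir(node):
            if not attr.startswith("_saved_"):
                continue
            try:
                value = getattr(node, attr)
            except RuntimeError:  # already freed by a previous backward()
                continue
            values = value if isinstance(value, (tuple, list)) else (value,)
            for v in values:
                if torch.is_tensor(v):
                    yield node_name, attr, v

        # Custom autograd.Function nodes: tensors passed to ctx.save_for_backward ...
        try:
            saved = node.saved_tensors
        except (AttributeError, RuntimeError):  # not a custom Function, or freed
            saved = ()
        for i, t in enumerate(saved):
            if torch.is_tensor(t):
                yield node_name, f"saved_tensors[{i}]", t

        # ... and tensors stashed directly on ctx (ctx.foo = t), which are not
        # released when backward() frees the saved buffers.
        for attr, value in getattr(node, "__dict__", {}).items():
            if torch.is_tensor(value):
                yield node_name, attr, value

        stack.extend(next_node for next_node, _ in node.next_functions)
```

For example, calling `for name, attr, t in saved_tensors_in_graph(loss.grad_fn): print(name, attr, tuple(t.shape))` just before `loss.backward()` lists what the graph holds, and running it again afterwards shows what survives the backward pass.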