PyTorch appears to be crashing due to OOM prematurely?

@ptrblck is there a better way than `out.nelement()` to find the real activation size of a network? Consider a UNet, for example: `out.nelement()` only counts the elements of the final output tensor, not of all the intermediate tensors produced along the way. I want to estimate how increasing the batch size increases memory consumption, so I'd like to know the actual size of the activations during the forward pass.
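
For reference, here is a rough sketch of the kind of measurement I have in mind: register a forward hook on every leaf module and sum the byte sizes of their outputs during one forward pass. I know this is only an approximation (it misses temporaries inside fused kernels and can double-count views or in-place outputs), and `estimate_activation_bytes` is just a name I made up for illustration:

```python
import torch
import torch.nn as nn

def estimate_activation_bytes(model, sample_input):
    """Approximate activation memory: sum the byte sizes of every
    leaf module's output tensor(s) during a single forward pass."""
    sizes = []

    def hook(module, inputs, output):
        outs = output if isinstance(output, (tuple, list)) else (output,)
        for o in outs:
            if torch.is_tensor(o):
                sizes.append(o.nelement() * o.element_size())

    # Hook only leaf modules so containers aren't counted twice.
    handles = [m.register_forward_hook(hook)
               for m in model.modules()
               if len(list(m.children())) == 0]
    with torch.no_grad():  # output shapes are the same with or without autograd
        model(sample_input)
    for h in handles:
        h.remove()
    return sum(sizes)

# Toy example: see how the estimate scales with batch size.
model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1),
                      nn.ReLU(),
                      nn.Conv2d(16, 3, 3, padding=1))
for bs in (1, 8):
    n_bytes = estimate_activation_bytes(model, torch.randn(bs, 3, 64, 64))
    print(f"batch size {bs}: ~{n_bytes / 1024**2:.2f} MiB of module outputs")
```

Would comparing this against `torch.cuda.max_memory_allocated()` (after calling `torch.cuda.reset_peak_memory_stats()` and running a forward/backward pass on the GPU) be a more reliable way to see how memory actually scales with batch size?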