PyTorch appears to be crashing prematurely due to OOM?

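`nelement()` returns the number of elements in a tensor, not its size in bytes (multiply by `element_size()` to get bytes). It also only tells you about the tensors you call it on, such as the parameters. The intermediate forward activations need memory as well, and during training autograd keeps the ones it needs for the backward pass. If you want to check the intermediate activation sizes, you could check them separately as described e.g. in this post.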
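A minimal sketch of one way to check activation sizes separately: register a forward hook on each submodule and record `output.nelement() * output.element_size()`. The toy `nn.Sequential` model, the input shape and the `activation_bytes` / `make_hook` names are hypothetical placeholders for your own model. This is only a rough estimate. It records each module's output size, but does not account for gradients, optimizer states, exactly which tensors autograd saves, or caching-allocator overhead.

```python
import torch
import torch.nn as nn

# Hypothetical toy model standing in for the model you are debugging.
model = nn.Sequential(
    nn.Linear(1024, 4096),
    nn.ReLU(),
    nn.Linear(4096, 10),
)

activation_bytes = {}  # submodule name -> bytes of its forward output


def make_hook(name):
    def hook(module, inputs, output):
        # nelement() counts elements; element_size() is bytes per element.
        activation_bytes[name] = output.nelement() * output.element_size()
    return hook


# Register a hook on every submodule (skip the root module, whose name is "").
handles = [
    m.register_forward_hook(make_hook(name))
    for name, m in model.named_modules()
    if name
]

x = torch.randn(64, 1024)  # hypothetical batch of 64 inputs
out = model(x)

# Remove the hooks so they do not fire on later forward passes.
for h in handles:
    h.remove()

param_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())

for name, nbytes in activation_bytes.items():
    print(f"activation {name}: {nbytes / 1024**2:.2f} MiB")
print(f"parameters: {param_bytes / 1024**2:.2f} MiB")
```

Comparing the per-layer activation sizes with the parameter size, for your real batch size, can show whether the activations rather than the parameters dominate memory use.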