PyTorch appears to be crashing due to OOM prematurely?

The memory usage is model-dependent and often the majority of the memory is used by the forward activations, not the parameters or gradients.
E.g., in this post I've posted some stats about the model in question, and you can see that the activations use:

6,452,341,200 / 138,357,544 ~= 47

times more memory than the parameters.
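
For reference, here is a minimal sketch of how you could check this for your own setup. It assumes a CUDA device and uses a torchvision model as a stand-in (the original post used a different model, so the absolute numbers will differ), but the measurement approach is the same:

```python
import torch
import torchvision.models as models

device = "cuda"
model = models.resnet152().to(device)
x = torch.randn(64, 3, 224, 224, device=device)

# Parameter memory: numel * bytes-per-element summed over all parameters.
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

# Activation memory: difference in allocated memory across the forward pass.
# Intermediate activations are kept alive for the backward pass, so they show
# up in memory_allocated() as long as `out` (and its graph) stays in scope.
torch.cuda.synchronize()
before = torch.cuda.memory_allocated()
out = model(x)
torch.cuda.synchronize()
after = torch.cuda.memory_allocated()
act_bytes = after - before

print(f"parameters : {param_bytes / 1024**2:.1f} MB")
print(f"activations: {act_bytes / 1024**2:.1f} MB")
print(f"ratio      : {act_bytes / param_bytes:.1f}x")
```

The ratio scales with the batch size and input resolution, since activations grow with both while the parameter memory stays fixed, so try reducing the batch size (or using torch.utils.checkpoint) if the activations dominate your memory usage.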