Rethinking memory estimates for training

Hello all and thanks for taking the time!

I know there are many variations of roughly the same question: “how do I estimate GPU memory usage for training a given model?” Broadly, the answer (according to pytorch-summary and pytorch-modelsize) is:

sum prod weights_sizes + 2 sum prod output_sizes + input_size
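To make that formula concrete, here is a minimal Python sketch that evaluates it as an element count for a fully sequential network. The `Layer` record, the `original_estimate` helper and the toy MLP shapes (batch size 32) are hypothetical illustrations, not code from pytorch-summary or pytorch-modelsize. Multiplying by 4 to get bytes assumes float32 storage.

```python
from dataclasses import dataclass, field
from math import prod


@dataclass
class Layer:
    """Hypothetical per-layer record; shapes include the batch dimension."""
    name: str
    input_shape: tuple
    output_shape: tuple
    weight_shapes: list = field(default_factory=list)


def original_estimate(layers, input_shape):
    """Element count for: sum prod weights + 2 sum prod outputs + input_size."""
    weights = sum(prod(s) for layer in layers for s in layer.weight_shapes)
    outputs = sum(prod(layer.output_shape) for layer in layers)
    return weights + 2 * outputs + prod(input_shape)


# Toy MLP (batch 32): Linear(784->256), ReLU, Dropout, Linear(256->10).
net = [
    Layer("fc1", (32, 784), (32, 256), [(256, 784), (256,)]),
    Layer("relu", (32, 256), (32, 256)),
    Layer("dropout", (32, 256), (32, 256)),
    Layer("fc2", (32, 256), (32, 10), [(10, 256), (10,)]),
]

elements = original_estimate(net, (32, 784))
print(elements, "elements,", elements * 4, "bytes at float32")
```

For this toy network it prints 278410 elements, or 1113640 bytes (about 1.1 MB) at float32.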

I understand that these can only ever be rough estimates. However, assuming a fully sequential network, I’m struggling to see why the following would not be a better estimate:

2 sum prod trainable_weights_sizes + sum prod untrainable_weights_sizes + sum prod (input_sizes|trainable_weights) + 2 max (prod output_sizes, prod input_sizes)
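The proposed formula can be sketched the same way. This is again a minimal illustration under stated assumptions: the `Layer` fields `trainable_shapes` and `untrainable_shapes` and the `proposed_estimate` helper are hypothetical, the dropout mask is modelled as an untrainable tensor the size of its input, and the last term is read as twice the largest single input or output tensor anywhere in the network.

```python
from dataclasses import dataclass, field
from math import prod


@dataclass
class Layer:
    """Hypothetical per-layer record for a fully sequential network."""
    name: str
    input_shape: tuple
    output_shape: tuple
    trainable_shapes: list = field(default_factory=list)    # get gradients
    untrainable_shapes: list = field(default_factory=list)  # e.g. dropout mask


def proposed_estimate(layers):
    """Element count for the proposed formula, one variable per term."""
    # Trainable weights, counted twice below to include their gradients.
    trainable = sum(prod(s) for layer in layers for s in layer.trainable_shapes)
    # Non-trainable tensors that must be kept (e.g. dropout masks).
    untrainable = sum(prod(s) for layer in layers for s in layer.untrainable_shapes)
    # Inputs stored for backward, only for layers with trainable weights.
    stored_inputs = sum(prod(layer.input_shape) for layer in layers
                        if layer.trainable_shapes)
    # Largest single input or output, counted twice (input and output live).
    largest = max(max(prod(layer.input_shape), prod(layer.output_shape))
                  for layer in layers)
    return 2 * trainable + untrainable + stored_inputs + 2 * largest


# Toy MLP (batch 32): Linear(784->256), ReLU, Dropout, Linear(256->10).
net = [
    Layer("fc1", (32, 784), (32, 256), trainable_shapes=[(256, 784), (256,)]),
    Layer("relu", (32, 256), (32, 256)),
    Layer("dropout", (32, 256), (32, 256), untrainable_shapes=[(32, 256)]),
    Layer("fc2", (32, 256), (32, 10), trainable_shapes=[(10, 256), (10,)]),
]

elements = proposed_estimate(net)
print(elements, "elements,", elements * 4, "bytes at float32")
```

On this toy network the sketch returns 498708 elements. The doubled weight term dominates here because the weights are the largest tensors in such a small network.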

The terms break down as follows:

  1. 2 sum prod trainable_weights_sizes accounts for both the trainable weights and their gradients.
  2. sum prod untrainable_weights_sizes accounts for non-trainable tensors (e.g. dropout masks) that must also be stored.
  3. sum prod (input_sizes|trainable_weights) reflects that a layer’s input is stored for the backward pass only when the layer has trainable weights. The stored input is what the layer needs to compute the partial derivatives of the error with respect to its weights. Layers without trainable weights, such as activations or reductions that sum out a dimension, do not need to store their inputs or outputs, because the only partial derivatives they compute are the ones that pass the error signal back to their input.
  4. 2 max (prod output_sizes, prod input_sizes) reflects the short-term requirement of holding both the input and the output of a layer in memory at the same time. It’s an upper bound.
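Item 4 can be checked in isolation. The sketch below assumes a purely sequential forward pass with no in-place operations and ignores allocator overhead; it computes the largest input + output pair actually held by any one layer and compares it with the 2 max (prod output_sizes, prod input_sizes) term. The `transient_peak` helper and the shapes are hypothetical.

```python
from math import prod


def transient_peak(layers):
    """layers: list of (input_shape, output_shape) pairs in execution order.

    Returns (exact, bound): the largest input + output element count held by
    any single layer, and the 2 * max(prod input, prod output) upper bound.
    """
    exact = max(prod(inp) + prod(out) for inp, out in layers)
    bound = 2 * max(max(prod(inp), prod(out)) for inp, out in layers)
    return exact, bound


# Toy MLP (batch 32): Linear(784->256), ReLU, Dropout, Linear(256->10).
shapes = [((32, 784), (32, 256)), ((32, 256), (32, 256)),
          ((32, 256), (32, 256)), ((32, 256), (32, 10))]
exact, bound = transient_peak(shapes)
print(exact, bound)
```

For these shapes it prints 33280 50176. The bound is only tight when the layer holding the largest tensor has an input and output of equal size.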

Could anyone tell me what is wrong with my reasoning? I’m not asserting this; I’m asking. I think my estimate will tend to be lower than the original, but I’ve read that the original estimate already underestimates, so I feel like I’m missing something. As for the original estimate, I think not multiplying the trainable weights by 2 is just an oversight, but I can’t figure out why the backward pass would need to hold all outputs in memory.

Thanks again for your time,