How to deal with excessive memory usage in PyTorch?

I don’t know how much memory your system has, but your training script would indeed take approx. 25GB of RAM to store the following:

  • input batch: 225 * 3 * 224 * 224 = 33868800 elements
  • model parameters: 138357544 elements
  • gradients: 138357544 elements (same as params)
  • internal running estimates of the optimizer: 138357544 elements
  • intermediate forward activations: 6452341200 elements
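
Summing these up and assuming 4 bytes per element (float32) gives approx.:

(33868800 + 3 * 138357544 + 6452341200) * 4 bytes ≈ 27.6e9 bytes ≈ 25.7GB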

You can check the approximate memory requirement using:

# model, criterion, optimizer, and tr_dataloader are assumed to come from your training script

# register forward hooks to record the intermediate activation sizes
acts = []
for name, module in model.named_modules():
    # skip the container modules ('features', 'classifier'), otherwise their
    # outputs would be counted twice on top of their children's outputs
    if name == 'classifier' or name == 'features':
        continue
    module.register_forward_hook(lambda m, inp, out: acts.append(out.detach()))

# execute single training step
X, y_true = next(iter(tr_dataloader))
# Forward pass
y_hat = model(X) 
loss = criterion(y_hat, y_true)         
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()

# approximate memory requirements (counted in float32 elements)
model_param_size = sum(p.nelement() for p in model.parameters())
grad_size = model_param_size  # one gradient per parameter
batch_size = 225 * 3 * 224 * 224  # input batch
# one internal running estimate (e.g. momentum buffer) per parameter
optimizer_size = sum(p.nelement() for p in optimizer.param_groups[0]['params'])
act_size = sum(a.nelement() for a in acts)  # intermediate forward activations

total_nb_elements = model_param_size + grad_size + batch_size + optimizer_size + act_size
total_gb = total_nb_elements * 4 / 1024**3  # 4 bytes per float32 element
print(total_gb)
> 25.709281235933304
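
If you also want to see which modules contribute the most, a small variation of the same hook approach could record the module name next to the element count (just a sketch; I would register these hooks on a fresh model instance so they don't interfere with the earlier hooks, and the top-5 printout is only for illustration):

named_acts = []
for name, module in model.named_modules():
    # skip the root module and the containers to avoid double counting
    if name in ('', 'classifier', 'features'):
        continue
    # bind the name via a default argument so each hook keeps its own module name
    module.register_forward_hook(
        lambda m, inp, out, name=name: named_acts.append((name, out.nelement()))
    )

model(X)
# print the five largest activations in GB (assuming float32)
for name, nelem in sorted(named_acts, key=lambda t: t[1], reverse=True)[:5]:
    print(f"{name}: {nelem * 4 / 1024**3:.3f} GB")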

On my system the peak memory usage comes close to the calculated estimate at approx. 24.8GB.
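
If you want to verify the peak host memory of the process yourself, one option (a minimal sketch, assuming Linux, where ru_maxrss is reported in kilobytes) is the resource module:

import resource

# peak resident set size of the current process so far (kilobytes on Linux)
peak_rss_kb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
print(f"peak RSS: {peak_rss_kb / 1024**2:.1f} GB")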
