PyTorch appears to be crashing due to OOM prematurely?

It really depends on the model architecture. Conv layers, for example, typically have very few parameters but produce large activations, while linear layers tend to show the inverse effect (many parameters, small activations), so the actual memory usage can differ hugely from what the parameter count alone suggests.
Here is a smaller example:

import torch
import torch.nn as nn

# conv
model = nn.Conv2d(3, 64, 3, 1, 1)
x = torch.randn(1, 3, 224, 224)

out = model(x)

model_param_size = sum([p.nelement() for p in model.parameters()])
input_size = x.nelement()
act_size = out.nelement()

print('model size: {}\ninput size: {}\nactivation size: {}'.format(
    model_param_size, input_size, act_size))

> model size: 1792
  input size: 150528
  activation size: 3211264
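
The conv layer itself has only 1792 parameters, but its activation grows linearly with the batch size and spatial resolution, which is usually what triggers the OOM for conv-heavy models. A quick sketch with a larger (made-up) input, reusing the conv model from above:

x_large = torch.randn(8, 3, 448, 448)  # bigger batch and resolution (hypothetical values)
out_large = model(x_large)
print('activation size: {}'.format(out_large.nelement()))
# -> 8 * 64 * 448 * 448 = 102760448 elements, ~32x the previous activation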
  
# linear
model = nn.Linear(1024, 1024)
x = torch.randn(1, 1024)

out = model(x)

model_param_size = sum([p.nelement() for p in model.parameters()])
input_size = x.nelement()
act_size = out.nelement()

print('model size: {}\ninput size: {}\nactivation size: {}'.format(
    model_param_size, input_size, act_size))

> model size: 1049600
  input size: 1024
  activation size: 1024
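
To turn these element counts into an actual memory estimate, multiply by the element size (4 bytes for float32). Also keep in mind that during training the forward activations are kept alive for the backward pass, so for conv-heavy models the activations, not the parameters, usually dominate the GPU memory. A minimal sketch, assuming float32 tensors and a CUDA device:

conv_act_elems = 3211264  # conv activation element count from above
bytes_per_elem = 4        # float32
print('approx. conv activation memory: {:.2f} MB'.format(
    conv_act_elems * bytes_per_elem / 1024**2))
# -> ~12.25 MB for a single sample; multiply by the batch size for training

# cross-check on the GPU via the allocator stats
if torch.cuda.is_available():
    model = nn.Conv2d(3, 64, 3, 1, 1).cuda()
    x = torch.randn(1, 3, 224, 224, device='cuda')
    out = model(x)
    print('allocated: {:.2f} MB'.format(
        torch.cuda.memory_allocated() / 1024**2))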