How to deal with excessive memory usage in PyTorch?

I am a beginner in PyTorch. For a project I am taking a VGG16 model (not pretrained) and training it from scratch. I have two seemingly identical pieces of code, one in Keras and one in PyTorch.

Keras Code:

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

trdata = ImageDataGenerator()
traindata = trdata.flow_from_directory(directory="Cat_Dog_data/train", target_size=(224,224))

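# The original arguments were cut off. weights=None (not pretrained) follows from the text above;
# classes=2 is an assumption (cat vs. dog).
model = tf.keras.applications.VGG16(weights=None, classes=2)

from tensorflow.keras.optimizers import Adam
opt = Adam(learning_rate=0.001)
model.compile(optimizer=opt, loss=tf.keras.losses.categorical_crossentropy, metrics=['accuracy'])
model.fit(traindata, epochs=100, steps_per_epoch=100)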

PyTorch Code:

from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader

from torch import nn, optim
from tqdm import tqdm

image_transform = transforms.Compose([transforms.Resize(size=(224,224)), transforms.ToTensor()])
tr_dataset = datasets.ImageFolder(root="Cat_Dog_data/train", transform=image_transform)
# I have 22500 images, so to make steps_per_epoch 100 (same as the Keras code) the batch size is set to 225.
tr_dataloader = DataLoader(tr_dataset, batch_size=225, shuffle=True)

from torchvision import models

model = models.vgg16(pretrained=False)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for t in range(100):
    print(f'Epoch {t}/100')
    for X, y_true in tqdm(tr_dataloader):
        # Forward pass
        y_hat = model(X)
        loss = criterion(y_hat, y_true)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

The Keras version trains without trouble, takes around 27 minutes per epoch on CPU, and has no dramatic effect on memory. But when I run the PyTorch code, it immediately takes up all available memory and the computer becomes extremely sluggish. I can’t train the PyTorch model on the GPU either, as I have 4GB of GPU memory and the code fails with an error demanding 6GB.

Any suggestions on how to make this work? Are there extra settings in PyTorch that I am missing? Thanks in advance.

I can’t speak to the CPU portion of this. To get this working on a GPU, you could use half the batch size you are currently using, which should take up around 3GB and fit on your GPU. You could also reduce the image size. Gradient accumulation lets you use a smaller batch size without reducing the model’s ability to learn.
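
For illustration, gradient accumulation might look roughly like the sketch below. The tiny nn.Linear model, the random micro-batches and the name accum_steps are hypothetical stand-ins chosen so the snippet runs quickly; they are not from the original code.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(10, 2)  # hypothetical stand-in for the VGG16 model
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Assumption: 3 micro-batches of 75 samples emulate one batch of 225.
accum_steps = 3
micro_batches = [
    (torch.randn(75, 10), torch.randint(0, 2, (75,)))
    for _ in range(accum_steps)
]

optimizer.zero_grad()
for X, y_true in micro_batches:
    # Divide by accum_steps so the summed gradients match the full-batch mean loss.
    loss = criterion(model(X), y_true) / accum_steps
    loss.backward()  # gradients are added to .grad, not overwritten
optimizer.step()  # one parameter update for the whole effective batch
```

In the training loop from the question, the same pattern would mean a DataLoader with batch_size=75 and a call to optimizer.step() followed by optimizer.zero_grad() after every third batch. Only one micro-batch of activations is held in memory at a time.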

Yes, I can do that, but that wasn’t the point of the question. I want to know why, for seemingly identical code, PyTorch takes up so much more memory than its Keras counterpart. The difference is huge.

I don’t know how much memory your system has, but your training script would indeed need approx. 25GB of RAM to store:

  • input batch: 225 * 3 * 224 * 224 = 33868800 elements
  • model parameters: 138357544 elements
  • gradients: 138357544 elements (same as params)
  • internal running estimates of the optimizer: 138357544 elements
  • intermediate forward activations: 6452341200 elements
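
As a rough cross-check of the list above, the element counts can be converted to memory by hand. This sketch assumes every element is a float32 value (4 bytes); the dictionary keys are just labels made up for the printout.

```python
# Element counts quoted in the list above.
counts = {
    "input batch": 225 * 3 * 224 * 224,
    "parameters": 138357544,
    "gradients": 138357544,
    "optimizer estimates": 138357544,
    "forward activations": 6452341200,
}
BYTES_PER_ELEMENT = 4  # assumption: float32 everywhere

# Per-item memory in GiB
for name, n in counts.items():
    print(f"{name:>20}: {n * BYTES_PER_ELEMENT / 1024**3:6.2f} GiB")

total_gb = sum(counts.values()) * BYTES_PER_ELEMENT / 1024**3
print(f"{'total':>20}: {total_gb:6.2f} GiB")
```

The intermediate forward activations account for well over 90% of the total, which is why the batch size has such a large effect on memory.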

You can check the approximated memory requirement using:

# register forward hooks to check intermediate activation size
acts = []
for name, module in model.named_modules():
    # skip the two container modules so their outputs are not counted twice
    if name == 'classifier' or name == 'features':
        continue
    module.register_forward_hook(lambda m, input, output: acts.append(output.detach()))

# execute single training step
X, y_true = next(iter(tr_dataloader))
# Forward pass
y_hat = model(X)
loss = criterion(y_hat, y_true)
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()

# approximate memory requirements
model_param_size = sum([p.nelement() for p in model.parameters()])
grad_size = model_param_size
batch_size = 225 * 3 * 224 * 224
optimizer_size = sum([p.nelement() for p in optimizer.param_groups[0]['params']])
act_size = sum([a.nelement() for a in acts])

total_nb_elements = model_param_size + grad_size + batch_size + optimizer_size + act_size
total_gb = total_nb_elements * 4 / 1024**3  # 4 bytes per float32 element
print(total_gb)
# > 25.709281235933304

On my system the peak memory requirement is close to the calculated estimate, at about 24.8GB.


Hi, thanks for the information. Any tips on how to reduce the memory usage, other than reducing the batch size and accumulating gradients?

You could use torch.utils.checkpoint to trade compute for memory. Checkpointed activations are recomputed during the backward pass instead of being stored, so this would slow down your code but save memory.
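
A minimal sketch of how that could look, assuming a recent PyTorch version where checkpoint accepts use_reentrant=False. The small features and classifier modules and the helper forward_with_checkpoint are hypothetical stand-ins for the corresponding parts of VGG16, kept tiny so the snippet runs quickly.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Hypothetical small stand-in for model.features / model.classifier of VGG16.
features = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 8, kernel_size=3, padding=1), nn.ReLU(),
    nn.MaxPool2d(2),
)
classifier = nn.Sequential(nn.Flatten(), nn.Linear(8 * 16 * 16, 2))
criterion = nn.CrossEntropyLoss()

def forward_with_checkpoint(x):
    # Activations inside `features` are not stored during the forward pass;
    # they are recomputed when backward() reaches this segment.
    h = checkpoint(features, x, use_reentrant=False)
    return classifier(h)

X = torch.randn(4, 3, 32, 32)
y_true = torch.tensor([0, 1, 0, 1])

loss = criterion(forward_with_checkpoint(X), y_true)
loss.backward()  # gradients match those of a regular forward pass
```

With the real model, one option would be to checkpoint model.features in several segments rather than as a whole, since only the activations at segment boundaries then need to be stored.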