I noticed that if I do a single forward pass with a batch size of 8, the GPU memory used is around 20 GB, whereas if I do 4 forward passes with a batch size of 2, the GPU memory used is around 28 GB.
Both runs push the same amount of data through the model, so where is the memory difference coming from?
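In case it helps with reproducing the numbers, peak usage per iteration can be checked with something like this (a minimal sketch using torch.cuda.max_memory_allocated; nvidia-smi will report somewhat more because of PyTorch's caching allocator):

import torch

torch.cuda.reset_peak_memory_stats()
# ... run one training iteration here ...
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")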
NOTE: I have cuDNN disabled, in case that is relevant.
Also, would you say these two methods of training would produce similar results? (See the gradient-comparison sketch after the second code block for what I mean by "similar".)
Below is some dummy code to reproduce the issue.
Preliminary code:
import torch
from torch import nn, optim
import torch.utils.data as data_utils
import torchvision.models as models
from torch.backends import cudnn

cudnn.enabled = False  # cuDNN disabled, as noted above

torch.cuda.set_device(1)  # use GPU 1 for all .cuda() calls below
device = torch.device("cuda")

# 320 dummy images (3 x 640 x 640) with constant class-1 labels
train_data = torch.randn(320, 3, 640, 640)
train_labels = torch.ones(320).long()
train = data_utils.TensorDataset(train_data, train_labels)
train_loader = data_utils.DataLoader(train, batch_size=8, shuffle=True, pin_memory=True)

criterion = nn.CrossEntropyLoss().cuda()
model = models.densenet161().cuda()
model.train()
optimizer = optim.Adam(model.parameters(), lr=0.001)
Code for the case of multiple forward passes:
for x, y in train_loader:
    x, y = x.to('cuda', non_blocking=True), y.to('cuda', non_blocking=True)
    # four forward passes over 2-sample chunks of the same batch
    pred1 = model(x[:2])
    loss1 = criterion(pred1, y[:2])
    pred2 = model(x[2:4])
    loss2 = criterion(pred2, y[2:4])
    pred3 = model(x[4:6])
    loss3 = criterion(pred3, y[4:6])
    pred4 = model(x[6:8])
    loss4 = criterion(pred4, y[6:8])
    loss = loss1 + loss2 + loss3 + loss4
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
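For comparison, a common variant of the chunked version calls backward() on each chunk immediately, which lets each chunk's activation graph be freed before the next forward pass; a sketch, in case the difference matters for the answer:

for x, y in train_loader:
    x, y = x.to('cuda', non_blocking=True), y.to('cuda', non_blocking=True)
    optimizer.zero_grad()
    for i in range(0, 8, 2):
        pred = model(x[i:i+2])
        loss = criterion(pred, y[i:i+2])
        loss.backward()  # accumulates into .grad and frees this chunk's graph
    optimizer.step()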
Code for the case of a single forward pass:
for x, y in train_loader:
    x, y = x.to('cuda', non_blocking=True), y.to('cuda', non_blocking=True)
    # one forward pass over the full batch, losses on 2-sample slices
    pred = model(x)
    loss1 = criterion(pred[:2], y[:2])
    loss2 = criterion(pred[2:4], y[2:4])
    loss3 = criterion(pred[4:6], y[4:6])
    loss4 = criterion(pred[6:8], y[6:8])
    loss = loss1 + loss2 + loss3 + loss4
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
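Regarding whether the two methods produce similar results, what I mean is comparing the gradients they produce on the same batch from the same initial weights. A rough sketch of that check (note the eval() calls: DenseNet contains BatchNorm layers, and in train() mode the chunked version sees per-chunk batch statistics, so the two versions are not exactly equivalent there):

import copy

x, y = next(iter(train_loader))
x, y = x.cuda(), y.cuda()

model_a = models.densenet161().cuda().eval()  # eval() freezes BatchNorm stats
model_b = copy.deepcopy(model_a)

# version 1: four forward passes over 2-sample chunks
loss_a = sum(criterion(model_a(x[i:i+2]), y[i:i+2]) for i in range(0, 8, 2))
loss_a.backward()

# version 2: one forward pass, losses on 2-sample slices
pred = model_b(x)
loss_b = sum(criterion(pred[i:i+2], y[i:i+2]) for i in range(0, 8, 2))
loss_b.backward()

diff = max((pa.grad - pb.grad).abs().max().item()
           for pa, pb in zip(model_a.parameters(), model_b.parameters()))
print(f"max gradient difference: {diff:.2e}")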
P.S. I understand that the loss calculation in the single-forward-pass code looks odd, but my actual code computes different losses for different parts of the output, so I kept that structure here in case it is part of the problem.
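One more thing I noticed while writing this up: with the default reduction='mean' of nn.CrossEntropyLoss, the sum of the four chunk losses is 4x the plain full-batch loss, because each chunk loss is averaged over 2 samples rather than 8. A tiny standalone check:

import torch
from torch import nn

criterion = nn.CrossEntropyLoss()  # default reduction='mean'
pred, y = torch.randn(8, 10), torch.randint(0, 10, (8,))

chunked = sum(criterion(pred[i:i+2], y[i:i+2]) for i in range(0, 8, 2))
full = criterion(pred, y)
print(torch.allclose(chunked, 4 * full))  # True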