Hello, all
I am new to PyTorch and I am seeing strange GPU memory behavior while training a CNN model for semantic segmentation. Batch size = 1, and there are 100 image-label pairs in the training set, so 100 iterations per epoch. However, GPU memory consumption increases a lot over the first several iterations of training.
[Platform] GTX TITAN X (12 GB), CUDA 7.5, cuDNN 5.0
torch.backends.cudnn.enabled = False
torch.backends.cudnn.benchmark = False
Then GPU memory consumption over the first six iterations is 2934 MB – 4413 MB – 4433 MB – 4537 MB – 4537 MB – 4537 MB.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
Then GPU memory consumption over the first six iterations is 1686 MB – 1791 MB – 1791 MB – 1791 MB – 1791 MB – 1791 MB.
Why does GPU memory consumption increase during training, and in particular, why does it increase so much when cuDNN is disabled? (In my opinion, GPU memory consumption should not keep growing once the CNN has been built and training has started.)
Has anyone met the same problem? Could anyone offer some help?
This is the code snippet:
def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        target = target.long()
        input = input.cuda(async=True)
        target = target.cuda(async=True)
        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)

        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)

        # record loss
        losses.update(loss.data[0], input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                      epoch, i + 1, len(train_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses))
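In case it helps anyone reproducing this: the numbers above came from nvidia-smi, which reports the total reserved by the process, not just live tensors. A minimal sketch of per-iteration logging from inside the loop, assuming a newer PyTorch than the one in the original post (`torch.cuda.memory_allocated` and `torch.cuda.max_memory_allocated` were added in later releases), with `mib` as a hypothetical helper name:

```python
def mib(nbytes):
    """Convert a byte count to MiB for readable logging."""
    return nbytes / (1024 ** 2)

def log_gpu_memory(step):
    """Print allocator stats for the current iteration, if CUDA is usable."""
    import torch
    if not torch.cuda.is_available():
        return
    # memory_allocated: bytes held by live tensors;
    # max_memory_allocated: high-water mark since the program started.
    print('iter {}: allocated {:.0f} MiB, peak {:.0f} MiB'.format(
        step,
        mib(torch.cuda.memory_allocated()),
        mib(torch.cuda.max_memory_allocated())))
```

Calling `log_gpu_memory(i)` at the end of each iteration would show whether the growth comes from live tensors or from the caching allocator's reserved pool.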