How to calculate the gradient of the loss on test data without increasing allocated CUDA memory

After training a model, I tried to use it to calculate the gradient of the loss on the test data, but the allocated CUDA memory kept increasing while iterating over the test loader, and eventually I ran out of CUDA memory.

import gpustat
import torch
from torch.autograd import grad

def show_mem_usage(device=1):
    # Print used/total memory of one GPU as reported by gpustat
    gpu_stats = gpustat.GPUStatCollection.new_query()
    item = gpu_stats.jsonify()["gpus"][device]
    print('Used/total: ' + "{}/{}".format(item["memory.used"], item["memory.total"]))

def compute_grad_testdata(model, test_loader):
    model.eval()
    batch = 0  # number of batches seen so far
    loss = 0
    show_mem_usage()

    for x_test, y_test in test_loader:
        show_mem_usage()
        with torch.cuda.device(CUDA_VISIBLE_DEVICES):
            x_test = x_test.cuda()
            y_test = y_test.cuda()
        score, feature, _ = model(x_test)
        batch_loss = calc_loss(score, y_test)
        loss += batch_loss
        batch += 1
        show_mem_usage()
    model.zero_grad()
    show_mem_usage()
    loss = loss/batch  # mean loss over all batches
    params = [p for p in model.parameters() if p.requires_grad]

    return list(grad(loss, params, create_graph=True))

Used/total: 3362/24268
Used/total: 3362/24268
Used/total: 3362/24268
Used/total: 3362/24268
Used/total: 3366/24268
Used/total: 3366/24268
Used/total: 3370/24268
Used/total: 3370/24268
Used/total: 3376/24268
Used/total: 3376/24268
Used/total: 3380/24268
Used/total: 3380/24268
Used/total: 3384/24268
Used/total: 3384/24268
Used/total: 3390/24268
Used/total: 3390/24268
Used/total: 3526/24268
Used/total: 3526/24268
Used/total: 3680/24268
Used/total: 3680/24268
Used/total: 3848/24268
Used/total: 3848/24268
Used/total: 4060/24268
Used/total: 4060/24268
Used/total: 4214/24268
Used/total: 4214/24268
Used/total: 4426/24268
Used/total: 4426/24268
Used/total: 4578/24268
Used/total: 4578/24268
Used/total: 4792/24268
Used/total: 4792/24268
Used/total: 4944/24268
Used/total: 4944/24268
Used/total: 5158/24268
Used/total: 5158/24268
Used/total: 5310/24268
Used/total: 5310/24268
Used/total: 5522/24268
Used/total: 5522/24268
Used/total: 5676/24268
Used/total: 5676/24268
Used/total: 5888/24268
Used/total: 5888/24268
Used/total: 6040/24268
Used/total: 6040/24268
Used/total: 6254/24268
Used/total: 6254/24268
Used/total: 6406/24268
Used/total: 6406/24268
Used/total: 6620/24268
Used/total: 6620/24268
Used/total: 6772/24268
Used/total: 6772/24268
Used/total: 6984/24268
Used/total: 6984/24268
Used/total: 7138/24268
Used/total: 7138/24268
Used/total: 7350/24268
Used/total: 7350/24268
Used/total: 7504/24268
Used/total: 7504/24268
Used/total: 7716/24268
Used/total: 7716/24268
Used/total: 7868/24268
Used/total: 7868/24268
Used/total: 8082/24268
Used/total: 8082/24268
Used/total: 8234/24268
Used/total: 8234/24268
Used/total: 8446/24268
Used/total: 8446/24268
Used/total: 8600/24268
Used/total: 8600/24268
Used/total: 8812/24268
Used/total: 8812/24268
Used/total: 8966/24268
Used/total: 8966/24268
Used/total: 9178/24268
Used/total: 9178/24268
Used/total: 9330/24268
Used/total: 9330/24268
Used/total: 9544/24268
Used/total: 9544/24268
Used/total: 9696/24268
Used/total: 9696/24268
Used/total: 9908/24268
Used/total: 9908/24268
Used/total: 10062/24268
Used/total: 10062/24268
Used/total: 10274/24268
Used/total: 10274/24268
Used/total: 10428/24268
Used/total: 10428/24268
Used/total: 10640/24268
Used/total: 10640/24268
Used/total: 10792/24268
Used/total: 10792/24268
Used/total: 11006/24268
Used/total: 11006/24268
Used/total: 11158/24268
Used/total: 11158/24268
Used/total: 11372/24268
Used/total: 11372/24268
Used/total: 11524/24268
Used/total: 11524/24268
Used/total: 11736/24268
Used/total: 11736/24268
Used/total: 11890/24268
Used/total: 11890/24268
Used/total: 12102/24268
Used/total: 12102/24268
Used/total: 12254/24268
Used/total: 12254/24268
Used/total: 12468/24268
Used/total: 12468/24268
Used/total: 12620/24268
Used/total: 12620/24268
Used/total: 12834/24268
Used/total: 12834/24268
Used/total: 12986/24268
Used/total: 12986/24268
Used/total: 13198/24268
Used/total: 13198/24268
Used/total: 13352/24268
Used/total: 13352/24268
Used/total: 13564/24268
Used/total: 13564/24268
Used/total: 13716/24268
Used/total: 13716/24268
Used/total: 13930/24268
Used/total: 13930/24268
Used/total: 14082/24268
Used/total: 14082/24268
Used/total: 14296/24268
Used/total: 14296/24268
Used/total: 14448/24268
Used/total: 14448/24268
Used/total: 14660/24268
Used/total: 14660/24268
Used/total: 14814/24268
Used/total: 14814/24268
Used/total: 15026/24268
Used/total: 15026/24268
Used/total: 15180/24268
Used/total: 15180/24268
Used/total: 15392/24268
Used/total: 15392/24268
Used/total: 15544/24268
Used/total: 15544/24268
Used/total: 15758/24268
Used/total: 15758/24268
Used/total: 15910/24268
Used/total: 15910/24268
Used/total: 16122/24268
Used/total: 16122/24268
Used/total: 16276/24268
Used/total: 16276/24268
Used/total: 16288/24268
Used/total: 16288/24268
........
RuntimeError: CUDA out of memory. Tried to allocate 32.00 MiB (GPU 0; 23.70 GiB total capacity; 22.07 GiB already allocated; 7.69 MiB free; 22.34 GiB reserved in total by PyTorch)

I know that for testing we usually wrap the loop in “with torch.no_grad()”, which disables gradient tracking so that newly computed tensors have requires_grad=False. I tried this, and it does stop the memory growth, but then I cannot use autograd.grad to calculate the gradient. Is there a way to calculate the gradient on test data without the allocated CUDA memory increasing? Thanks.

You are storing the batch_loss tensor with the entire computation graph in these lines of code:

        batch_loss = calc_loss(score, y_test)
        loss += batch_loss

Because loss holds a reference to every batch's graph, memory usage grows with each iteration.
You could avoid this by calculating the gradients inside the loop, which frees each batch's computation graph along with its intermediate activation tensors.

Thanks. But how can I calculate the gradient inside the loop? Mathematically, I need the mean loss over all test data points (not just one batch of them) to calculate the gradient on the test data. If I put grad(loss, params) inside the loop, it computes the gradients for the current batch only, right?

You could still scale the gradients while accumulating them inside the loop. Something like the second approach below could work (the first approach is only the reference to compare against):

import torch
import torch.nn as nn
import torchvision.models as models

model = models.resnet18()
criterion = nn.CrossEntropyLoss()

x = torch.randn(100, 3, 224, 224)
y = torch.randint(0, 1000, (100,))

# 1: reference, sums all losses first (keeps every graph in memory)
loss = 0.
for x_, y_ in zip(x, y):
    x_ = x_.unsqueeze(0)
    y_ = y_.unsqueeze(0)
    output = model(x_)
    batch_loss = criterion(output, y_)
    loss += batch_loss

loss = loss/len(x)
params = [p for p in model.parameters() if p.requires_grad]
grad_ref = list(torch.autograd.grad(loss, params))


# 2: compute and scale the gradients per sample (graph is freed each iteration)
grad = None
params = [p for p in model.parameters() if p.requires_grad]
for x_, y_ in zip(x, y):
    x_ = x_.unsqueeze(0)
    y_ = y_.unsqueeze(0)
    output = model(x_)
    batch_loss = criterion(output, y_)
    if grad is None:
        grad = [g/len(x) for g in list(torch.autograd.grad(batch_loss, params))]
    else:
        tmp = list(torch.autograd.grad(batch_loss, params))
        grad = [g + t/len(x) for g, t in zip(grad, tmp)]

# Compare both approaches; the differences should be close to zero
for g, g_ref in zip(grad, grad_ref):
    print((g - g_ref).abs().max())

Alternatively, you could divide each batch_loss by len(x) before calling torch.autograd.grad instead of scaling the gradients afterwards.
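A minimal sketch of that alternative, assuming a small nn.Linear stand-in model and random data (both are illustrative and not from the original code). Each per-sample loss is divided by the number of samples before torch.autograd.grad is called, and the results are summed in place:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(5, 3)  # small stand-in model (assumption, for illustration only)
criterion = nn.CrossEntropyLoss()
x = torch.randn(8, 5)
y = torch.randint(0, 3, (8,))
params = [p for p in model.parameters() if p.requires_grad]

# Reference: sum all per-sample losses first, then differentiate the mean once.
loss = sum(criterion(model(x_.unsqueeze(0)), y_.unsqueeze(0)) for x_, y_ in zip(x, y)) / len(x)
grad_ref = torch.autograd.grad(loss, params)

# Variant: scale each per-sample loss, then add its gradients to a running sum.
grad_acc = [torch.zeros_like(p) for p in params]
for x_, y_ in zip(x, y):
    batch_loss = criterion(model(x_.unsqueeze(0)), y_.unsqueeze(0)) / len(x)
    for acc, g in zip(grad_acc, torch.autograd.grad(batch_loss, params)):
        acc.add_(g)  # in-place add, so no new parameter-sized tensors pile up

# Largest absolute difference between the two results
max_diff = max((a - r).abs().max().item() for a, r in zip(grad_acc, grad_ref))
print(max_diff)
```

The printed difference should be at floating-point noise level, since the gradient of a sum equals the sum of the gradients.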

Thanks a lot, ptrblck. But the memory usage still increased dramatically even with the second method (maybe the problem is that I set create_graph=True in autograd.grad). I am just wondering why the memory usage does not increase when the gradient calculation is inside the loop. Also, when do I need create_graph=True?

Most likely you don’t need create_graph=True, as your code also runs without it. This option is needed if you want to e.g. calculate the second derivative, which is why I removed it in my example.
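As a small illustration of when create_graph=True matters, here is a sketch on a scalar toy function (the function y = x**3 is an assumption chosen for illustration). With create_graph=True the returned gradient is itself part of a graph and can be differentiated again; without it, the gradient is a plain tensor and the graph is freed:

```python
import torch

x = torch.tensor(3.0, requires_grad=True)

# With create_graph=True the gradient computation is recorded,
# so the first derivative can be differentiated a second time.
y = x ** 3
(dy_dx,) = torch.autograd.grad(y, x, create_graph=True)  # 3 * x**2
(d2y_dx2,) = torch.autograd.grad(dy_dx, x)               # 6 * x

# Without create_graph the result carries no graph ...
z = x ** 3
(dz_dx,) = torch.autograd.grad(z, x)

# ... and the graph of z has been freed, so a second call fails.
graph_freed = False
try:
    torch.autograd.grad(z, x)
except RuntimeError:
    graph_freed = True

print(dy_dx.requires_grad, dz_dx.requires_grad, graph_freed)
```

Keeping the graph of the gradient around is what costs the extra memory, so it is only worth paying for when higher-order derivatives are actually needed.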
The memory usage does not increase if you calculate the gradients inside the loop, since the computation graph is freed after each torch.autograd.grad call (as long as create_graph and retain_graph are not set), i.e. the intermediate activations needed to compute the gradients are deleted. The gradients are then summed into a fixed set of parameter-shaped tensors (the grad list above), so no additional tensors need to be kept around.
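Applied to a data loader, the same idea could look roughly like the sketch below. The helper name mean_loss_grad, the toy model, and the random dataset are hypothetical. It also assumes that the model returns the scores directly (the model in the question returns a tuple), that loss_fn returns the batch mean, and that the dataset supports len():

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def mean_loss_grad(model, loader, loss_fn, device="cpu"):
    """Gradient of the mean per-sample loss over all data in `loader` (hypothetical helper)."""
    model.eval()
    params = [p for p in model.parameters() if p.requires_grad]
    grads = [torch.zeros_like(p) for p in params]
    n_total = len(loader.dataset)  # assumes a sized dataset
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        # loss_fn gives the batch mean; reweight so each sample counts 1/n_total,
        # which stays exact even if the last batch is smaller.
        batch_loss = loss_fn(model(x), y) * (x.size(0) / n_total)
        # No create_graph/retain_graph: this batch's graph is freed after the call.
        for acc, g in zip(grads, torch.autograd.grad(batch_loss, params)):
            acc.add_(g)
    return grads


# Usage on toy data: 10 samples in batches of 4, 4 and 2
torch.manual_seed(0)
model = nn.Linear(5, 3)
data = TensorDataset(torch.randn(10, 5), torch.randint(0, 3, (10,)))
loader = DataLoader(data, batch_size=4)
grads = mean_loss_grad(model, loader, nn.CrossEntropyLoss())
print([g.shape for g in grads])
```

Only one batch's graph is alive at any time, so peak memory should depend on the batch size rather than on the number of batches.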