Consistent Memory Leak Project Gradient Descent

Hi all,

I am running into a consistent memory leak in the following projected gradient descent code (for vision applications) and can’t seem to figure out why. Over time, the memory usage of the program goes up to fixed levels, but does not increase at each iteration, only sporadically. I have included some comments in the code about what I have already tried to fix this issue.

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import copy
import torchvision
import torch.cuda.amp as amp
import torchvision.transforms as transforms

def gpu_mem_usage():
    """Computes the GPU memory usage for the current device (MB)."""
    mem_usage_bytes = torch.cuda.max_memory_allocated()
    return mem_usage_bytes / 1024 / 1024

class PGD(nn.Module):
    def __init__(self, model, epsilon, num_steps, step_size, num_restarts, verbose=False):
        self.model = model
        self.epsilon = epsilon
        self.num_steps = num_steps
        self.step_size = step_size
        self.num_restarts = num_restarts
        self.verbose = verbose

    def perturb(self, x, y):
        print("Mem allocated at beginning of perturb: " + str(gpu_mem_usage()))
        x_hat = x.detach() + torch.zeros_like(x.detach()).uniform_(-self.epsilon, self.epsilon)
        x_hat = torch.clamp(x_hat, 0, 1)

        for i in range(self.num_steps):
            print("Mem allocated at middle_1 of perturb: " + str(gpu_mem_usage()))

            self.model.zero_grad(set_to_none = True)

            with torch.enable_grad():
                print("Mem allocated at middle_2 of perturb: " + str(gpu_mem_usage()))
                logits = self.model(x_hat)
                print("Mem allocated at middle_3 of perturb: " + str(gpu_mem_usage()))
                loss = F.cross_entropy(logits, y, reduction='sum')
                print("Mem allocated at middle_4 of perturb: " + str(gpu_mem_usage()))

                if self.verbose:
                    print("Average loss: {}".format(loss.item() / x.size(0)))
                print("Mem allocated at middle_5 of perturb: " + str(gpu_mem_usage()))

            grad = torch.autograd.grad(loss, x_hat, only_inputs=True)[0]
            # have tried loss.backward() here instead
            x_hat = x_hat + self.step_size * torch.sign(grad.detach())
            x_hat = torch.min(torch.max(x_hat, x - self.epsilon), x + self.epsilon)
            x_hat = torch.clamp(x_hat, 0, 1)

        with torch.no_grad():
            logits = self.model(x_hat)
            fooled = logits.argmax(1) != y
        print("Mem allocated at end of perturb: " + str(gpu_mem_usage()))
        return x_hat.detach(), fooled.detach()

    def forward(self, x, y):
        B = x.size(0)

        with torch.no_grad():
            logits = self.model(x)

        correct = logits.argmax(1) == y
        x_hat = x.detach()

        if self.verbose:
            print("Starting new batch.")

        for i in range(self.num_restarts):
            if torch.sum(correct) == 0:
                if self.verbose:
                    print("All images fooled, skipping remaining restarts.")
            ind = torch.nonzero(correct).squeeze(1)
            x_i, y_i = x[ind].detach(), y[ind].detach()
            x_hat_i, fooled_i = self.perturb(x_i, y_i)
            correct[ind[fooled_i]] = False
            x_hat[ind[fooled_i]] = x_hat_i[fooled_i].detach()
            if self.verbose:
                print("Correct after restart {} : {} / {}".format(i, torch.sum(correct), B))

        return x_hat.detach()

model = timm.models.create_model("resnet18",
transform = transforms.Compose(
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
loader =, batch_size=10,
                                         shuffle=False, num_workers=2)
# Enable eval mode
model_wrapped = model
adversary = PGD(model_wrapped,
                    verbose=True).cuda() #mem leak persists even with num restarts = 1

for (inputs, labels) in loader:
    # Transfer the data to the current GPU device
    inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
    # Denormalize inputs for ease of clipping
    inputs = F.interpolate(inputs, size = (224,224))
    # Disable DDP synchronization
    with amp.autocast(enabled=False):
        # Perturb inputs
        inputs = adversary(inputs, labels)
        # Compute the predictions
        preds = model_wrapped(inputs)

Any help or advice would be greatly appreciated. Thank you!