Hi all,
I am running into a consistent memory leak in the following projected gradient descent code (for vision applications) and can’t seem to figure out why. Over time, the memory usage of the program goes up to fixed levels, but does not increase at each iteration, only sporadically. I have included some comments in the code about what I have already tried to fix this issue.
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import copy
import torchvision
import torch.cuda.amp as amp
import torchvision.transforms as transforms
def gpu_mem_usage():
"""Computes the GPU memory usage for the current device (MB)."""
mem_usage_bytes = torch.cuda.max_memory_allocated()
return mem_usage_bytes / 1024 / 1024
class PGD(nn.Module):
def __init__(self, model, epsilon, num_steps, step_size, num_restarts, verbose=False):
super().__init__()
self.model = model
self.epsilon = epsilon
self.num_steps = num_steps
self.step_size = step_size
self.num_restarts = num_restarts
self.verbose = verbose
def perturb(self, x, y):
print("Mem allocated at beginning of perturb: " + str(gpu_mem_usage()))
x_hat = x.detach() + torch.zeros_like(x.detach()).uniform_(-self.epsilon, self.epsilon)
x_hat = torch.clamp(x_hat, 0, 1)
for i in range(self.num_steps):
print("Mem allocated at middle_1 of perturb: " + str(gpu_mem_usage()))
self.model.zero_grad(set_to_none = True)
x_hat.requires_grad_()
with torch.enable_grad():
print("Mem allocated at middle_2 of perturb: " + str(gpu_mem_usage()))
logits = self.model(x_hat)
print("Mem allocated at middle_3 of perturb: " + str(gpu_mem_usage()))
loss = F.cross_entropy(logits, y, reduction='sum')
print("Mem allocated at middle_4 of perturb: " + str(gpu_mem_usage()))
if self.verbose:
print("Average loss: {}".format(loss.item() / x.size(0)))
print("Mem allocated at middle_5 of perturb: " + str(gpu_mem_usage()))
grad = torch.autograd.grad(loss, x_hat, only_inputs=True)[0]
# have tried loss.backward() here instead
x_hat = x_hat + self.step_size * torch.sign(grad.detach())
x_hat = torch.min(torch.max(x_hat, x - self.epsilon), x + self.epsilon)
x_hat = torch.clamp(x_hat, 0, 1)
with torch.no_grad():
logits = self.model(x_hat)
fooled = logits.argmax(1) != y
print("Mem allocated at end of perturb: " + str(gpu_mem_usage()))
return x_hat.detach(), fooled.detach()
def forward(self, x, y):
B = x.size(0)
with torch.no_grad():
logits = self.model(x)
correct = logits.argmax(1) == y
x_hat = x.detach()
if self.verbose:
print("Starting new batch.")
for i in range(self.num_restarts):
if torch.sum(correct) == 0:
if self.verbose:
print("All images fooled, skipping remaining restarts.")
break
ind = torch.nonzero(correct).squeeze(1)
x_i, y_i = x[ind].detach(), y[ind].detach()
x_hat_i, fooled_i = self.perturb(x_i, y_i)
correct[ind[fooled_i]] = False
x_hat[ind[fooled_i]] = x_hat_i[fooled_i].detach()
if self.verbose:
print("Correct after restart {} : {} / {}".format(i, torch.sum(correct), B))
return x_hat.detach()
model = timm.models.create_model("resnet18",
num_classes=10,
pretrained=True)
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
loader = torch.utils.data.DataLoader(testset, batch_size=10,
shuffle=False, num_workers=2)
# Enable eval mode
model.eval()
model_wrapped = model
adversary = PGD(model_wrapped,
epsilon=8/255,
num_steps=40,
step_size=0.002,
num_restarts=5,
verbose=True).cuda() #mem leak persists even with num restarts = 1
for (inputs, labels) in loader:
# Transfer the data to the current GPU device
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
# Denormalize inputs for ease of clipping
inputs = F.interpolate(inputs, size = (224,224))
# Disable DDP synchronization
with amp.autocast(enabled=False):
# Perturb inputs
inputs = adversary(inputs, labels)
# Compute the predictions
preds = model_wrapped(inputs)
Any help or advice would be greatly appreciated. Thank you!