There is a memory leak when I train a model in data parallel mode if the model was generated by torch.fx.symbolic_trace().
There seems to be no memory leak if I turn off either of the following options:
- data parallel mode (`torch.nn.DataParallel`)
- the model being a `torch.fx.GraphModule` created via `torch.fx.Tracer`
The Python script that reproduces this bug is attached below. It leaks over 1 GB of memory per 10 training epochs.
import torch
import torch.nn as nn
import torch.fx as fx
import torchvision
import torchvision.transforms as transforms
batch_size = 64
class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.backbone = torchvision.models.densenet121()
self.linear = nn.Linear(1000, 10)
def forward(self, x):
out = self.backbone(x)
out = self.linear(out)
return out
train_dataset = torchvision.datasets.CIFAR10(root='data/',
train=True,
transform=transforms.ToTensor(),
download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
model = MyNet()
model = fx.symbolic_trace(model)
model = model.to('cuda')
model = nn.DataParallel(model)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
model.parameters(),
lr=0.01,
momentum=0.8,
nesterov=True,
weight_decay=0.0001)
for epoch in range(1000):
print('epoch:', epoch)
for i, (images, labels) in enumerate(train_loader):
images = images.to('cuda')
labels = labels.to('cuda')
outputs = model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
_, predicted = torch.max(outputs.reshape(-1, 10).data, 1)
correctness = (predicted == labels).sum().item() / batch_size
print('\ttraining loss: {:3.4f}'.format(loss.item()))
print('\ttraining correctness: {:3.4f}'.format(correctness))