Memory leak while training model generated by torch.fx.symbolic_trace() in data parallel mode

There is a memory leak when I train a model in data parallel mode, if the model was generated by torch.fx.symbolic_trace().
There seems to be no memory leak if I turn off either of the following options:

  • data parallel mode (torch.nn.DataParallel)
  • The model is a torch.fx.GraphModule created via torch.fx.Tracer

The Python script that reproduces this bug is attached below. The program leaks over 1 GB of
memory per 10 training epochs.

import torch
import torch.nn as nn
import torch.fx as fx
import torchvision
import torchvision.transforms as transforms

batch_size = 64

class MyNet(nn.Module):
    """DenseNet-121 backbone followed by a 10-way linear classifier head."""

    def __init__(self):
        super().__init__()
        # Backbone emits 1000-dim logits (ImageNet head); map them to 10 classes.
        self.backbone = torchvision.models.densenet121()
        self.linear = nn.Linear(1000, 10)

    def forward(self, x):
        # Features from the backbone, then the final classification layer.
        return self.linear(self.backbone(x))

# CIFAR-10 training set (downloaded to data/ on first run).
train_dataset = torchvision.datasets.CIFAR10(root='data/',
                                                train=True,
                                                transform=transforms.ToTensor(),
                                                download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                            batch_size=batch_size, 
                                            shuffle=True)

model = MyNet()
# Symbolically trace the model into a torch.fx.GraphModule; combined with
# nn.DataParallel below, this is the configuration that reproduces the leak.
model = fx.symbolic_trace(model)

model = model.to('cuda')
model = nn.DataParallel(model)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
                model.parameters(),
                lr=0.01,
                momentum=0.8,
                nesterov=True,
                weight_decay=0.0001)

for epoch in range(1000):
    print('epoch:', epoch)
    for i, (images, labels) in enumerate(train_loader):

        images = images.to('cuda')
        labels = labels.to('cuda')
        
        outputs = model(images)

        loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        _, predicted = torch.max(outputs.reshape(-1, 10).data, 1)
        # Divide by the actual number of samples in this batch: the final
        # batch of an epoch may be smaller than `batch_size`, so using the
        # fixed constant would understate accuracy there.
        correctness = (predicted == labels).sum().item() / labels.size(0)

        print('\ttraining loss: {:3.4f}'.format(loss.item()))
        print('\ttraining correctness: {:3.4f}'.format(correctness))

PyTorch: 1.10.0

System: ubuntu 18.04

CUDA: 10.2