DataParallel() creates garbage

Hi, I’m working with a very large dataset of 600,000,000 data points.
Probably because of the data size, a garbage collection (gc.collect()) takes a long time, about 10 s. This makes my training far slower.
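
(For reference, a full collection has to walk every container object CPython tracks, so its cost grows with how much lives in memory. A minimal sketch that shows the effect on synthetic data, nothing below is my real code:)

import gc
import time

# Many small container objects, standing in for a huge in-memory dataset.
data = [[i] for i in range(50000000)]

start = time.time()
gc.collect()  # a full collection walks all tracked container objects
print('gc.collect() took {:.1f}s'.format(time.time() - start))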

I profiled my code and found that the DataParallel() module was causing the problem. The issue can be reproduced with very simple code:

import gc

import torch
from torch.autograd import Variable

# SimpleCNN, criterion, optimizer, and train_loader are defined elsewhere.
model = SimpleCNN()
model = torch.nn.DataParallel(model).cuda()

def train(epoch):
    model.train()
    cc = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data[0]))

        # Collect after every batch to count how much garbage each step creates.
        collected = gc.collect()
        cc += collected
        print('collected', collected)

    print('collected', cc)

train(0)

Results:

Train Epoch: 0 [0/60000 (0%)]	Loss: 2.306420
collected 125
collected 118
collected 118
collected 118
collected 118
collected 118

...

But the number of collected objects dropped to zero after removing the DataParallel() wrapper:

model = SimpleCNN()
# model = torch.nn.DataParallel(model)
model = model.cuda()

train(0)

Results:

Train Epoch: 0 [0/60000 (0%)]	Loss: 2.307985
collected 56
collected 0
collected 0
collected 0
collected 0
collected 0
collected 0

...

I want to eliminate the garbage that DataParallel creates; that should make my training faster.
How can I solve this problem?
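
For what it's worth, here is a sketch of how to inspect what the collected garbage actually consists of, using gc's debug flags (the single forward pass is just illustrative):

import gc
from collections import Counter

# DEBUG_SAVEALL keeps unreachable objects in gc.garbage instead of
# freeing them, so they can be inspected after a collection.
gc.set_debug(gc.DEBUG_SAVEALL)

output = model(data)  # one forward pass through the DataParallel model
gc.collect()

# Tally the collected objects by type to see where the cycles come from.
print(Counter(type(obj).__name__ for obj in gc.garbage))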

Why do you even need gc.collect()? Calling it manually at every iteration will slow things down rather than improve performance.

I called gc.collect() manually just for testing, to clarify the problem. The original code does not call gc.collect() manually.

The important thing is that garbage is being created. Even if gc.collect() is never called manually, garbage keeps accumulating, so CPython's collector runs on its own periodically (it is triggered by allocation counts crossing the generation thresholds, not by memory pressure), and those automatic collections slow training down.
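
If the cycles inside DataParallel can't be avoided, one workaround (a sketch, assuming the per-iteration garbage stays bounded) is to keep the automatic collector from firing mid-batch and collect at a cheaper moment instead:

import gc

# CPython triggers collection from allocation counts, not memory pressure.
print(gc.get_threshold())  # defaults to (700, 10, 10)
print(gc.get_count())      # allocations since the last collection, per generation

# Either raise the generation-0 threshold so collections fire less often...
gc.set_threshold(100000, 10, 10)

# ...or disable automatic collection and collect once per epoch, when a
# pause hurts less than in the middle of a batch.
gc.disable()
for epoch in range(10):  # epoch count is just illustrative
    train(epoch)
    gc.collect()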