Hi, I’m working on a very large dataset with 600,000,000 data points.
Perhaps because of the data size, garbage collection (gc.collect()) takes a long time, about 10 s. This makes my training far slower.
I profiled my code and found that the DataParallel module was causing the problem. It can be reproduced with very simple code:
import gc

import torch
from torch.autograd import Variable

# SimpleCNN, train_loader, optimizer and criterion are not shown here.
model = SimpleCNN()
model = torch.nn.DataParallel(model).cuda()

def train(epoch):
    model.train()
    cc = 0  # running total of collected objects
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data[0]))
        # Force a collection after every batch and report how many objects it found.
        collected = gc.collect()
        cc += collected
        print("collected {}".format(collected))
    print("collected {}".format(cc))

train(0)
Results:
Train Epoch: 0 [0/60000 (0%)] Loss: 2.306420
collected 125
collected 118
collected 118
collected 118
collected 118
collected 118
...
But the number of collected objects drops to zero after the first batch when I remove the DataParallel() wrapper:
model = SimpleCNN()
# model = torch.nn.DataParallel(model)
model = model.cuda()
train(0)
Results:
Train Epoch: 0 [0/60000 (0%)] Loss: 2.307985
collected 56
collected 0
collected 0
collected 0
collected 0
collected 0
collected 0
...
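For reference, the number that gc.collect() returns is the count of unreachable objects it found. In CPython, objects that are not part of a reference cycle are freed right away by reference counting, so a nonzero count on every batch suggests that each iteration leaves reference cycles behind. The following is a minimal, PyTorch-free sketch of that behaviour; the Holder class and both helper functions are hypothetical names used only for illustration.

```python
import gc


class Holder(object):
    """Tiny placeholder object (hypothetical, for illustration only)."""
    pass


def make_plain():
    # No cycle: reference counting frees the object when the function returns.
    a = Holder()
    a.value = 1


def make_cycle():
    # Two objects that point at each other: only the cyclic collector can free them.
    a, b = Holder(), Holder()
    a.peer, b.peer = b, a


was_enabled = gc.isenabled()
gc.disable()              # keep automatic collections from interfering with the counts
try:
    gc.collect()          # flush any garbage that already exists

    make_plain()
    no_cycle = gc.collect()

    make_cycle()
    with_cycle = gc.collect()
finally:
    if was_enabled:
        gc.enable()

print("collected without cycle: {}".format(no_cycle))
print("collected with cycle: {}".format(with_cycle))
```

The exact count for the cyclic case depends on the Python version (instance dictionaries may or may not be counted), but it is at least the two Holder objects.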
I want to get rid of the garbage that DataParallel creates, since that should make my training faster.
How can I solve this problem?
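To narrow down which objects are behind these counts, something like the following might help. It is only a diagnostic sketch, not a fix: it uses the standard gc.DEBUG_SAVEALL flag, which makes the collector keep unreachable objects in gc.garbage instead of freeing them, and then tallies them by type. The function name garbage_type_counts and the Node class are hypothetical; in the training loop the function would be called in place of the plain gc.collect() for a few batches.

```python
import gc
from collections import Counter


def garbage_type_counts():
    """Run one collection and count the unreachable objects it finds, by type name."""
    old_flags = gc.get_debug()
    gc.set_debug(gc.DEBUG_SAVEALL)  # keep unreachable objects in gc.garbage
    try:
        gc.collect()
        counts = Counter(type(obj).__name__ for obj in gc.garbage)
    finally:
        gc.set_debug(old_flags)
        del gc.garbage[:]           # drop the saved references again
    return counts


class Node(object):
    """Hypothetical stand-in for whatever object ends up in a cycle."""
    pass


gc.collect()                        # start from a clean state
a, b = Node(), Node()
a.other, b.other = b, a             # build a two-object reference cycle
del a, b

demo_counts = garbage_type_counts()
print(demo_counts)
```

Running this with and without the DataParallel wrapper and comparing the two tallies should show which object types only appear in the DataParallel case.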