DistributedSampler memory consumption in PyTorch 0.4.0


#1

Hello everyone,

I’m running into a problem: the memory consumption shown on my dashboard looks very strange.

When training on a single GPU on a single machine, each process uses about 8 GB of memory at runtime. When I train on multiple machines, each process’s memory consumption grows from 8 GB to 32 GB.

However, the consumption drops back down when I use multiple machines but don’t use DistributedSampler.

I don’t understand why this makes such a big difference. What could be causing it, and how can I fix it?

My code is as follows:

# Imports inferred from the aliases used below
import torch.distributed as dist
import torch.utils.data as Data
import torch.utils.data.distributed as distUtils

corpus_dataset = CorpusDataset(h5py_path, self.word2Vec, self.args.maxL_input, self.args.maxL_output)

train_sampler = None
if self.args.distributed:
    dist.init_process_group(backend=self.args.distBackend, init_method=self.args.distUrl,
                            world_size=self.args.worldSize, rank=self.args.rank)
    train_sampler = distUtils.DistributedSampler(corpus_dataset, self.args.worldSize, self.args.rank)

custom_loader = Data.DataLoader(
    dataset=corpus_dataset,
    batch_size=self.args.batchSize,
    shuffle=(train_sampler is None),
    drop_last=(train_sampler is not None),
    num_workers=1,
    collate_fn=collate_fn,
    sampler=train_sampler
)
   

for epoch in range(self.args.numEpochs):
    for posts, p_lens, responses, r_lens, labels in custom_loader:
        self.optimizer.zero_grad()
        score = self.dual_encoder(posts, p_lens, responses, r_lens)
        loss = self.loss_fc(score, labels)
        loss.backward()
        if self.args.distributed:
            self.average_gradients(self.dual_encoder)
        self.optimizer.step()
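
(`average_gradients` is not shown here; below is only a minimal sketch of what such a helper typically looks like, assuming it follows the usual all-reduce pattern of summing gradients across processes and dividing by the world size. The actual implementation used in my code may differ.)

# Hypothetical sketch of an average_gradients helper, assuming the
# standard all-reduce pattern; not the poster's actual implementation.
import torch.distributed as dist

def average_gradients(model, world_size):
    for param in model.parameters():
        if param.grad is not None:
            # Sum this parameter's gradient across all processes
            # (SUM is the default reduction op), then average it.
            dist.all_reduce(param.grad.data)
            param.grad.data /= world_size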

#2

I found something. I rewrote line 41 of the DistributedSampler class,

indices = list(torch.randperm(len(self.dataset), generator=g))

as follows:

indices = torch.randperm(len(self.dataset), generator=g).numpy().tolist()

This works for me, and memory consumption now stays at a stable level.
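
A plausible explanation (my assumption, not verified against the sampler internals): `list(torch.randperm(n))` produces a Python list of n zero-dimensional tensors, each carrying full tensor-object overhead, while converting to a list of plain Python ints avoids that. A quick sketch comparing the two representations:

# Quick illustration (not from the thread) of what the old and new
# lines actually produce for the index list.
import torch

g = torch.Generator()
g.manual_seed(0)
n = 1000000

old_indices = list(torch.randperm(n, generator=g))              # list of 0-dim tensors
new_indices = torch.randperm(n, generator=g).numpy().tolist()   # list of plain ints

print(type(old_indices[0]))  # <class 'torch.Tensor'> - one tensor object per index
print(type(new_indices[0]))  # <class 'int'>          - a plain Python int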


#3

A commit fixing this has since been merged.