DistributedSampler memory consumption in PyTorch 0.4.0

Hello everyone,

I’m running into a problem: the memory consumption shown on my dashboard looks very strange.

When training on a single machine with a single GPU, each process uses about 8 GB of memory at runtime. When I train on multiple machines, each process’s memory consumption jumps from 8 GB to 32 GB.

But the memory consumption goes back down when I use multiple machines without DistributedSampler.

I don’t understand why this makes such a big difference. What could cause it, and how can I fix it?

My code is as follows:

corpus_dataset = CorpusDataset(h5py_path, self.word2Vec, self.args.maxL_input, self.args.maxL_output)

train_sampler = None
if self.args.distributed:
    # initialize the process group before constructing the sampler
    dist.init_process_group(backend=self.args.distBackend, init_method=self.args.distUrl,
                            world_size=self.args.worldSize, rank=self.args.rank)
    train_sampler = distUtils.DistributedSampler(corpus_dataset, self.args.worldSize, self.args.rank)

custom_loader = Data.DataLoader(
    dataset=corpus_dataset,
    batch_size=self.args.batchSize,
    shuffle=(train_sampler is None),        # the sampler handles shuffling in the distributed case
    drop_last=(train_sampler is not None),
    num_workers=1,
    collate_fn=collate_fn,
    sampler=train_sampler
)
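
For context, here is roughly what DistributedSampler.__iter__ does in 0.4.0. This is a simplified sketch from memory of torch/utils/data/distributed.py, not a verbatim copy (self.total_size and self.num_samples come from the constructor). The important point is that every process first materializes an index list as long as the whole dataset before taking its own rank's slice:

def __iter__(self):
    # deterministically shuffle based on the current epoch
    g = torch.Generator()
    g.manual_seed(self.epoch)
    indices = list(torch.randperm(len(self.dataset), generator=g))  # line 41: full-dataset index list

    # pad so the list is evenly divisible across replicas
    indices += indices[:(self.total_size - len(indices))]

    # keep only this rank's slice
    offset = self.num_samples * self.rank
    indices = indices[offset:offset + self.num_samples]

    return iter(indices)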
   

for epoch in range(self.args.numEpochs):
    for posts, p_lens, responses, r_lens, labels in custom_loader:
        self.optimizer.zero_grad()
        score = self.dual_encoder(posts, p_lens, responses, r_lens)
        loss = self.loss_fc(score, labels)
        loss.backward()
        if self.args.distributed:
            self.average_gradients(self.dual_encoder)
        self.optimizer.step()
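
average_gradients isn't shown above; it follows something like the usual all-reduce-and-divide pattern sketched below (only a sketch, the actual function may differ in details):

import torch.distributed as dist

def average_gradients(model):
    # Sketch only: all-reduce each parameter's gradient across processes
    # and divide by the world size.
    world_size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM)
            param.grad.data /= world_size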

I found something. Line 41 of the DistributedSampler class originally reads:

indices = list(torch.randperm(len(self.dataset), generator=g))

and I rewrote it as follows:

indices = torch.randperm(len(self.dataset), generator=g).numpy().tolist()

This works for me, and the memory consumption now stays at a steady level.
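
The difference seems to be that in 0.4.0, list(tensor) produces a list of 0-dim torch.Tensor objects, while .numpy().tolist() produces plain Python ints. Each tensor wrapper is much heavier than an int, and since every process builds an index list as long as the dataset, the overhead adds up quickly. A minimal comparison (the dataset size here is made up purely for illustration):

import sys
import torch

n = 1000000  # made-up dataset size, for illustration only
g = torch.Generator()
g.manual_seed(0)

heavy = list(torch.randperm(n, generator=g))             # one million 0-dim tensors
light = torch.randperm(n, generator=g).numpy().tolist()  # one million plain Python ints

print(type(heavy[0]), sys.getsizeof(heavy[0]))  # torch.Tensor wrapper per index
print(type(light[0]), sys.getsizeof(light[0]))  # small int object per index

Tensor.tolist() should give the same result without the round trip through NumPy.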


This has since been fixed by a commit that was merged upstream.