I am currently playing around with distributed training, and I'm having some trouble understanding it:
When using torch.utils.data.distributed.DistributedSampler, does it split the whole dataset D into D1 and D2 (considering the one-node, two-GPU scenario), with each process iterating only over its own part, D1 or D2, so that no batch is ever built from a mix of D1 and D2 data? Or is the split performed again automatically each time an epoch ends and a new one starts?
I have a similar question. If this is how DistributedSampler samples data, and the portion of the data on each GPU is quite small, the network on each GPU may overfit to that small portion of the dataset.
Does the portion of the dataset assigned to each GPU change after each epoch?
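Hi, I found that this can be solved by calling sampler.set_epoch(e) at the start of every epoch, where e is the current epoch number. DistributedSampler seeds its shuffle with the epoch number, so without this call every epoch reuses the same permutation and each GPU keeps seeing the same subset. Try the code below: in each epoch, every GPU gets a different, freshly shuffled subset of the data.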
import os

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
# Join the process group; NCCL is the GPU collective backend.
torch.distributed.init_process_group(backend="nccl")
input_size = 5
output_size = 2
batch_size = 2
data_size = 16
# get_rank() is the *global* rank. On multi-node jobs it can exceed the number
# of GPUs on a node, so it is not a valid CUDA device index there. torchrun
# exports the per-node index as LOCAL_RANK; prefer it, and fall back to the
# global rank (correct on a single node) when the launcher does not set it.
local_rank = int(os.environ.get("LOCAL_RANK", torch.distributed.get_rank()))
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
class RandomDataset(Dataset):
    """Toy dataset whose i-th sample is a ``size``-length vector filled with ``i + 1``.

    Every row holds a distinct constant, so printing a batch shows exactly
    which sample indices a given process received from its sampler.

    Args:
        size: Number of features per sample.
        length: Number of samples in the dataset.
        local_rank: Rank of the owning process; stored on the instance only.
        device: Device the whole data tensor is moved to (defaults to "cuda",
            as before).
    """

    def __init__(self, size, length, local_rank, device="cuda"):
        self.len = length
        # Build row i as (i + 1) repeated `size` times, derived from the
        # arguments. Previously a hard-coded 16x5 tensor ignored `size` and
        # broke (IndexError) whenever `length` > 16.
        row_values = torch.arange(1, length + 1).unsqueeze(1)  # (length, 1)
        self.data = (torch.ones(length, size) * row_values).to(device)
        self.local_rank = local_rank

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len
dataset = RandomDataset(input_size, data_size, local_rank)
# Each rank draws a disjoint shard of the dataset's indices per epoch.
sampler = DistributedSampler(dataset)
rand_loader = DataLoader(dataset=dataset,
                         batch_size=batch_size,
                         sampler=sampler)

num_epochs = 2
for epoch in range(num_epochs):
    # DistributedSampler seeds its shuffle with the epoch number. Without this
    # call every epoch reuses the same permutation, so each rank would see
    # the same subset of the data forever.
    sampler.set_epoch(epoch)
    for data in rand_loader:
        # Label the output so the per-rank split is visible when ranks interleave.
        print(f"rank {local_rank} epoch {epoch}: {data}")

torch.distributed.destroy_process_group()