Distributed training dataloader takes too much time at the start of every epoch

Hello. I’m trying to implement a distributed batch sampler for a 3180K-sample audio dataset.
Defining the Dataset, Sampler, and DataLoader is quite fast.

But it takes 20 minutes before the first step starts:

for inputs in train_dataloader:   # <<<< takes 20 minutes to produce the first mini-batch (inputs)

The training steps themselves run fast, but at the start of every epoch it takes the same 20 minutes before the first step begins.

The code I used is below.
I don’t think creating the dynamic batches or loading the dataset pickle file should take that much time.
Maybe the shuffling is what takes so long.

Please let me know if I’m doing something wrong.
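
For reference, a minimal sketch of how I’d break the delay down (not code I have actually run; it assumes the train_sampler / train_dataloader objects defined further down):

import time

# Timing sketch: separate the sampler's __iter__ from DataLoader worker startup + first batch
t0 = time.time()
_ = list(iter(train_sampler))         # shuffle + per-rank slicing happen here
t1 = time.time()
loader_iter = iter(train_dataloader)  # worker processes are spawned here
first_batch = next(loader_iter)       # first mini-batch: np.load of the wavs + collation
t2 = time.time()
print(f'sampler iter: {t1 - t0:.1f}s | first batch: {t2 - t1:.1f}s')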

import os
import json

import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

def train(
    rank,
    config,
    port,
):
    world_size = config.world_size
    torch.manual_seed(config.seed)  # assumption: the seed lives on the config
    torch.cuda.set_device(rank)
    
    dist.init_process_group(
        backend='nccl', init_method=f"tcp://localhost:{str(port)}", 
        rank=rank, 
        world_size=world_size
    )

    train_dataset = DatasetSorted(config)
    collator = PadCollate()
    train_sampler = DistributedSortedDynamicBatchSampler(
        train_dataset, config, num_replicas=world_size, 
        rank=rank, shuffle=True
    )
    model = ASRModel(config, train_dataset.e_matrix).to(rank)  # assumption: e_matrix is the one loaded by the dataset

    model = FSDP(model)
  
    optimizer = torch.optim.Adam(
        params=model.parameters(),
        **config.base_optimizer_arguments
    )

    # DataLoader: batching and shuffling are handled entirely by the batch_sampler
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=(train_sampler is None),  # always False here; shuffling happens in the sampler
        num_workers=config.num_worker,
        collate_fn=collator,
        pin_memory=True,
        batch_sampler=train_sampler
    )
      
    step = 0  # assumption: the step counter starts here (elided in my snippet)
    # For every epoch
    for epoch_ in range(config.epochs):
        epoch = epoch_ + 1
        train_dataloader.batch_sampler.set_epoch(epoch)

        # [1] This loop takes 20 minutes to yield its first batch !!!!
        for inputs in train_dataloader:
            step += 1
            ....

class DatasetSorted(torch.utils.data.Dataset):
    def __init__(self, config):
        
        self.config = config
        
        # Sort dataset by wave length (longest first)
        
        self.data_df = pd.read_pickle(os.path.join(config.dataset_path, config.train_df_filename))
        
        if config.do_limit_audio_length:
            self.data_df = limit_length_in_dataset(self.data_df, 'wav', config)
        
        if config.sort_dataset:
            print('Sorting Dataset by Wave Length')
            self.data_df = self.data_df.sort_values(by='wav_length', ascending=False).reset_index(drop=True)
        else:
            print('Dataset Already Sorted')
        
        self.wav_lengths = [(idx, length) for idx, length in enumerate(self.data_df['wav_length'])]
        
        self.sampling_rate = config.sampling_rate
        
        with open(config.word_dict_path, 'r', encoding='utf8') as j:
            self.word_dict = json.load(j)
        
        self.e_matrix = torch.load(config.e_matrix_path)

                                
    def __getitem__(self, idx):

        text = [self.word_dict[mor] for mor in self.data_df['morphs'][idx] if mor in self.word_dict]
        
        wav_path = self.data_df['wav_path'][idx].replace('@datapath', self.config.dataset_path)
        wav = torch.from_numpy(np.load(wav_path, allow_pickle=True))
        wav_length = torch.Tensor([[wav.shape[0]]])
        
        assert not torch.isnan(wav).any(), f'wav input NaN {wav.shape} Path : {wav_path}'

        return {
            'text':text, 
            'wav' : wav,
            'wav_length' : wav_length
        }
    
    def __len__(self):
        return len(self.data_df)
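
For completeness, a sanity-check sketch (not run yet) for the claim above that loading the dataset pickle / individual wav files shouldn’t take that long:

import time

# Sketch: time dataset construction (pd.read_pickle + sort) and a few __getitem__ calls (np.load)
t0 = time.time()
ds = DatasetSorted(config)
t1 = time.time()
_ = [ds[i] for i in range(32)]
t2 = time.time()
print(f'__init__: {t1 - t0:.1f}s | 32 x __getitem__: {t2 - t1:.1f}s')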
    

class PadCollate():
    def __init__(self, dim=0):
        self.dim = dim
    
    def _pad_tensor(self, vec, pad, dim, pad_value):
        # Pad `vec` along `dim` up to length `pad` with `pad_value`
        pad_size = list(vec.shape)
        pad_size[dim] = pad - vec.size(dim)
        padded_vec = torch.cat([vec, torch.ones(*pad_size) * pad_value], dim=dim)
        return padded_vec
    
    def pad_collate(self, batch):
        wav_lengths = torch.cat([x['wav_length'] for x in batch], dim=0).long()
        
        wavs = torch.stack([
            self._pad_tensor(x['wav'], pad=int(torch.max(wav_lengths)), dim=self.dim, pad_value=0)
            for x in batch
        ], dim=0).half()
        texts = [x['text'] for x in batch]
                
        return {
            'wav': wavs, 'wav_lengths': wav_lengths.squeeze(-1).long(), 
            'text': texts,
        }

        
    def __call__(self, batch):
        return self.pad_collate(batch)
    

class DistributedSortedDynamicBatchSampler(torch.utils.data.distributed.DistributedSampler):
    def __init__(self, dataset, config, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        
        self.lengths = dataset.wav_lengths
        self.config = config
        self.batches, self.num_samples_per_batch = self._create_dynamic_batch()
        self.total_size = sum(self.num_samples_per_batch)
        
        self.num_samples = self.total_size // self.num_replicas
  
    def _create_dynamic_batch(self):
        batches = []
        num_samples_per_batch = []
        batch = []
        for idx, length in self.lengths:
            if batch:
                batch.append(idx)
                if len(batch) == batch_size:
                    batches.append(batch)
                    num_samples_per_batch.append(len(batch))
                    batch = []
            else:
                # Start a new bucket: its size is derived from its first (longest) clip,
                # so the bucket stays within config.batch_second.
                batch_size = int(self.config.batch_second / (length / self.config.sampling_rate))
                batch_size = batch_size - batch_size % self.num_replicas
                # If a single clip is already longer than batch_second, it is skipped.
                if batch_size == 0:
                    continue
                batch.append(idx)
                
        # last bucket
        if batch:
            batch = batch[:len(batch) - len(batch) % self.num_replicas]
            batches.append(batch)
            num_samples_per_batch.append(len(batch))
                
        return batches, num_samples_per_batch
   
    def __iter__(self):
        # Deterministically shuffle the bucket order based on the epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)

        if self.shuffle:
            batch_ids = torch.randperm(len(self.batches), generator=g).tolist()
        else:
            batch_ids = list(range(len(self.batches)))

        # Slice each bucket per rank into a local list; overwriting self.batches here
        # would shrink the batches further on every epoch.
        rank_batches = [self.batches[i][self.rank::self.num_replicas] for i in batch_ids]
        return iter(rank_batches)

    def __len__(self):
        # Number of batches per rank; the DataLoader uses this as its length
        return len(self.batches)
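
One more thing I’m not sure about: the num_workers worker processes get torn down and re-spawned at every epoch. Would keeping them alive help? A minimal sketch of what I mean (assuming a PyTorch version that supports persistent_workers, and assuming it plays nicely with a custom batch_sampler and set_epoch):

# Sketch only: keep the DataLoader workers alive between epochs instead of re-spawning them
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    num_workers=config.num_worker,
    collate_fn=collator,
    pin_memory=True,
    batch_sampler=train_sampler,
    persistent_workers=True,  # workers survive across epochs
    prefetch_factor=4,        # each worker keeps a few batches pre-loaded (assumption: default may be too low here)
)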

Thanks for asking. Looks like this is a dataloader-related question; can @VitalyFedyunin kindly take a look?