Hello. I’m trying to implement a distributed batch sampler for a dataset of ~3,180,000 audio clips.
Defining the Dataset, Sampler, and DataLoader is quite fast.
But it takes 20 minutes before the first training step starts.
for inputs in train_dataloader:
. <<<< takes 20 minutes to produce the first mini-batch (inputs).
Training steps themselves run fast, but at the start of every epoch it takes the same 20 minutes before the first step.
The code I used is like below.
I don’t think creating the dynamic batches or loading the dataset pickle file should take that much time.
Maybe the shuffling is taking most of the time.
Please let me know if I’m doing something wrong.
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
# Entry point for one distributed training worker process (presumably
# launched via torch.multiprocessing.spawn -- confirm with caller).
# `rank` serves as both the process rank and the local CUDA device index.
def train(
rank,
config,
port,
):
world_size = config.world_size
# NOTE(review): `seed` is not defined in this snippet -- presumably a
# module-level constant; verify it exists.
torch.manual_seed(seed)
torch.cuda.set_device(rank)
# Rendezvous over TCP on localhost -- assumes single-node training.
dist.init_process_group(
backend='nccl', init_method=f"tcp://localhost:{str(port)}",
rank=rank,
world_size=world_size
)
train_dataset = DatasetSorted(config)
collator = PadCollate()
# Batch sampler that yields whole index batches; passed to DataLoader below.
train_sampler = DistributedSortedDynamicBatchSampler(
train_dataset, config, num_replicas=world_size,
rank=rank, shuffle=True
)
# NOTE(review): `e_matrix` is not defined in this snippet (the dataset loads
# its own copy as `train_dataset.e_matrix`) -- verify where it comes from.
model = ASRModel(config, e_matrix).to(rank)
model = FSDP(model)
optimizer = torch.optim.Adam(
params=model.parameters(),
**config.base_optimizer_arguments
)
# Dataloader
# NOTE(review): `shuffle` is ignored when `batch_sampler` is given (and here
# it evaluates to False anyway, since train_sampler is not None).
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=(train_sampler is None),
num_workers=config.num_worker,
collate_fn=collator,
pin_memory=True,
batch_sampler=train_sampler
)
# for every Epoch
for epoch_ in range(config.epochs):
epoch = epoch_ + 1
# Re-seed the sampler's shuffle for this epoch.
train_dataloader.batch_sampler.set_epoch(epoch)
# [1] Took too much time !!!!
# NOTE(review): the long per-epoch first-batch delay is consistent with
# DataLoader worker startup cost -- with num_workers > 0 each worker
# re-imports and re-pickles the (very large) dataset object, and workers
# are respawned every epoch unless persistent_workers=True is set.
# Worth profiling worker startup before blaming the shuffle.
for inputs in train_dataloader:
# NOTE(review): `step` is not initialized in this snippet.
step += 1
....
class DatasetSorted(torch.utils.data.Dataset):
    """Audio dataset backed by a pickled metadata DataFrame.

    Rows are (optionally) sorted by wave length in descending order so the
    dynamic batch sampler can group similarly-sized clips together.
    """

    def __init__(self, config):
        self.config = config
        # Load the metadata table (columns used here: 'wav_length',
        # 'wav_path', 'morphs' -- schema assumed from usage below).
        self.data_df = pd.read_pickle(os.path.join(config.dataset_path, config.train_df_filename))
        if config.do_limit_audio_length:
            self.data_df = limit_length_in_dataset(self.data_df, 'wav', config)
        if config.sort_dataset:
            print('Sorting Dataset by Wave Length')
            self.data_df = self.data_df.sort_values(by='wav_length', ascending=False).reset_index(drop=True)
        else:
            print('Dataset Already Sorted')
        # (index, length) pairs consumed by DistributedSortedDynamicBatchSampler.
        self.wav_lengths = list(enumerate(self.data_df['wav_length']))
        self.sampling_rate = config.sampling_rate
        with open(config.word_dict_path, 'r', encoding='utf8') as j:
            self.word_dict = json.load(j)
        self.e_matrix = torch.load(config.e_matrix_path)

    def __getitem__(self, idx):
        """Return one sample: token ids, waveform tensor, and its length.

        Out-of-vocabulary morphemes are silently dropped (original behavior).
        """
        # `in dict` instead of `in dict.keys()` -- same result, no extra view.
        text = [self.word_dict[mor] for mor in self.data_df['morphs'][idx] if mor in self.word_dict]
        wav_path = self.data_df['wav_path'][idx].replace('@datapath', self.config.dataset_path)
        wav = torch.from_numpy(np.load(wav_path, allow_pickle=True))
        # Shape (1, 1) so PadCollate can torch.cat the lengths along dim 0.
        wav_length = torch.Tensor([[wav.shape[0]]])
        # Guard against corrupted audio arrays; `not ...any()` replaces the
        # original bitwise `~torch.any(...)` (equivalent truthiness, clearer).
        assert not torch.isnan(wav).any(), f'wav input NaN {wav.shape} Path : {wav_path}'
        return {
            'text': text,
            'wav': wav,
            'wav_length': wav_length
        }

    def __len__(self):
        return len(self.data_df)
class PadCollate():
    """Collate function that zero-pads variable-length waveforms to a batch.

    Pads every 'wav' tensor along `dim` up to the longest one in the batch,
    stacks them into a half-precision tensor, and collects lengths and texts.
    """

    def __init__(self, dim=0):
        # Dimension along which waveforms are padded (0 = time axis).
        self.dim = dim

    def _pad_tensor(self, vec, pad, dim, pad_value):
        """Right-pad `vec` with `pad_value` along `dim` to length `pad`."""
        pad_size = list(vec.shape)
        pad_size[dim] = pad - vec.size(dim)
        # BUGFIX: the original had a duplicated (unreachable) return here.
        return torch.cat([vec, torch.ones(*pad_size) * pad_value], dim=dim)

    def pad_collate(self, batch):
        """Build the batched dict from a list of DatasetSorted samples."""
        # Each 'wav_length' is shape (1, 1); cat -> (B, 1).
        wav_lengths = torch.cat([x['wav_length'] for x in batch], dim=0).long()
        # Pad everything to the longest clip, then stack -> (B, T_max), fp16.
        max_len = int(torch.max(wav_lengths))
        wavs = torch.stack(
            [self._pad_tensor(x['wav'], pad=max_len, dim=self.dim, pad_value=0)
             for x in batch],
            dim=0
        ).half()
        texts = [x['text'] for x in batch]
        return {
            'wav': wavs, 'wav_lengths': wav_lengths.squeeze(-1).long(),
            'text': texts,
        }

    def __call__(self, batch):
        return self.pad_collate(batch)
class DistributedSortedDynamicBatchSampler(torch.utils.data.distributed.DistributedSampler):
    """Distributed batch sampler producing length-aware dynamic batches.

    The dataset is assumed sorted by wave length (descending), so consecutive
    indices have similar lengths.  Each batch is sized so that it holds
    roughly ``config.batch_second`` seconds of audio, rounded down to a
    multiple of ``num_replicas``; each rank then takes a strided slice of
    every batch so all ranks step the same number of times per epoch.
    """

    def __init__(self, dataset, config, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        # (index, length) pairs, pre-sorted by descending length.
        self.lengths = dataset.wav_lengths
        self.config = config
        # Batches are built ONCE here; __iter__ must never mutate them.
        self.batches, self.num_samples_per_batch = self._create_dynamic_batch()
        self.total_size = sum(self.num_samples_per_batch)
        self.num_samples = self.total_size // self.num_replicas

    def _create_dynamic_batch(self):
        """Greedily group the sorted indices into batches.

        The batch size is recomputed from the first (longest) element of each
        batch: batch_second / clip_seconds, rounded down to a multiple of
        num_replicas.  Returns (batches, samples_per_batch).
        """
        batches = []
        num_samples_per_batch = []
        batch = []
        batch_size = 0
        for idx, length in self.lengths:
            if not batch:
                batch_size = int(self.config.batch_second / (length / self.config.sampling_rate))
                batch_size -= batch_size % self.num_replicas
                # A single clip longer than batch_second cannot be batched; skip it.
                if batch_size == 0:
                    continue
            batch.append(idx)
            # BUGFIX: flush as soon as the batch is full.  The original only
            # performed this check for the second and later elements, so a
            # batch_size of 1 grew without bound instead of flushing.
            if len(batch) == batch_size:
                batches.append(batch)
                num_samples_per_batch.append(len(batch))
                batch = []
        # Trailing partial batch: trim to a multiple of num_replicas.
        if batch:
            batch = batch[:len(batch) - len(batch) % self.num_replicas]
            # BUGFIX: the original could append an empty list here when the
            # trim removed everything, yielding an empty batch downstream.
            if batch:
                batches.append(batch)
                num_samples_per_batch.append(len(batch))
        return batches, num_samples_per_batch

    def __iter__(self):
        """Yield this rank's slice of each batch, shuffled per epoch."""
        if self.shuffle:
            # Deterministic shuffle seeded by the epoch set via set_epoch(),
            # so all ranks agree on the batch order.
            g = torch.Generator()
            g.manual_seed(self.epoch)
            order = torch.randperm(len(self.batches), generator=g).tolist()
        else:
            order = range(len(self.batches))
        # BUGFIX: the original assigned the rank-sliced batches back to
        # self.batches, so each subsequent epoch sliced the already-sliced
        # batches again (shrinking them by a factor of num_replicas per
        # epoch) and permanently scrambled the stored order.  It also skipped
        # rank-slicing entirely when shuffle=False, duplicating every sample
        # on every rank.  Build a fresh list instead and slice in all cases.
        return iter([self.batches[i][self.rank::self.num_replicas] for i in order])

    def __len__(self):
        # BUGFIX: a batch sampler's length is its number of BATCHES (what
        # DataLoader reports as len(dataloader)), not samples per replica.
        return len(self.batches)