Hi. My problem is the speed of HDF5 data loading; below I explain the problem and its background.
I recently started using PyTorch's DataLoader to feed a large dataset (33.33 GB of log-amplitude STFTs of audio files) to a neural network. Since each batch tensor is huge (batch_size, 625, 513), I have to keep the batch size at 4 at most and use gradient accumulation (see the training loop below). The small batch size leads to a lot of training steps per epoch (2167 steps for 8667 samples).
My problem is that in every step, fetching the data takes about 2 minutes, while all the other parts of the step take only a few milliseconds. With so many steps per epoch, I cannot train the network at all right now!
Any help would be really appreciated.
Here is my data loading part and training loop:
import h5py
import torch
from torch.utils.data import Dataset, DataLoader

class TorchGenerator(Dataset):
    # Constructor: keeps references to the HDF5 datasets (no data is copied here)
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return self.x.shape[0]

    # Getter: reads one sample from the HDF5 file and puts it on the GPU
    def __getitem__(self, idx):
        samplex = torch.tensor(self.x[idx], dtype=torch.float, device='cuda')
        sampley = torch.tensor(self.y[idx], dtype=torch.float, device='cuda')
        return samplex, sampley

data_f = h5py.File(path, "r")
x_train = data_f["X_train_arr"]
y_train = data_f["Y_train_arr"]
training_data = TorchGenerator(x_train, y_train)
train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE,
                              num_workers=8, pin_memory=True)
def train(train_dataloader, STEPS, model, loss_fn, optimizer):
    counter = 1
    model.train()
    optimizer.zero_grad()
    lossbatch = []
    for x, y in train_dataloader:
        x = x.cuda()
        y = y.cuda()
        predict = model(x)
        y_t = y[:, :, :, 1].squeeze(dim=-1).long()
        loss = loss_fn(predict, y_t)
        loss.backward()
        # Gradient accumulation: only update the weights every STEPS mini-batches
        if counter % STEPS == 0:
            optimizer.step()
            optimizer.zero_grad()
        loss_item = loss.detach()
        lossbatch.append(loss_item)
        counter += 1
    loss_t = torch.mean(torch.stack(lossbatch))
    return loss_t
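For completeness, train() is called once per epoch, roughly as below. EPOCHS and STEPS = 8 are illustrative values, not my exact settings; STEPS is the number of mini-batches accumulated before each optimizer update, so batch_size = 4 with STEPS = 8 would give an effective batch of 32.

for epoch in range(EPOCHS):  # EPOCHS is an illustrative name
    epoch_loss = train(train_dataloader, STEPS, model, loss_fn, optimizer)
    print(f"epoch {epoch + 1}: mean loss {epoch_loss.item():.4f}")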
P.S. 1: I put time.time() calls around every stage of each step, and the bottleneck is clearly the data loading at the very beginning of each step.
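The measurement was roughly like the sketch below (simplified; the real loop also runs the forward/backward pass shown in train() above):

import time
import torch

t0 = time.time()
for x, y in train_dataloader:
    t_fetch = time.time() - t0           # time spent waiting for the batch
    x, y = x.cuda(), y.cuda()
    # ... forward / backward as in train() above ...
    torch.cuda.synchronize()             # include pending GPU work in the timing
    t_rest = time.time() - t0 - t_fetch  # everything in the step after the fetch
    print(f"fetch: {t_fetch:.2f} s, rest: {t_rest:.4f} s")
    t0 = time.time()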
P.S. 2: As you can see, I load the data with 8 workers, but it didn't improve the speed much.
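P.S. 3: One thing I am wondering about but have not tried yet: I have read that an open h5py.File handle may not be shared cleanly across DataLoader worker processes, which could explain why the 8 workers did not help. A variant of my dataset that opens the file lazily inside each worker would look roughly like this (an untested sketch; LazyH5Dataset is a hypothetical name, the dataset keys are the same as above, and the tensors stay on the CPU so the .cuda() calls in the training loop do the transfer):

import h5py
import torch
from torch.utils.data import Dataset

class LazyH5Dataset(Dataset):
    # Hypothetical variant: each worker opens its own file handle on first use
    def __init__(self, path, x_key="X_train_arr", y_key="Y_train_arr"):
        self.path, self.x_key, self.y_key = path, x_key, y_key
        self.file = None
        with h5py.File(path, "r") as f:   # open briefly just to read the length
            self.length = f[x_key].shape[0]

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if self.file is None:             # first call inside this worker process
            self.file = h5py.File(self.path, "r")
        x = torch.from_numpy(self.file[self.x_key][idx]).float()
        y = torch.from_numpy(self.file[self.y_key][idx]).float()
        return x, y                       # stay on CPU; .cuda() happens in the loop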