Hi, I am training a simple decoder (5M params) for a voice recognition problem. My input data is in two separate files:
- `x_tensor.pt`: 1.1GB, shape [3577, 200, 384]
- `y_tensor.pt`: 30KB, shape [3577]
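A quick sanity check on these sizes, assuming 4-byte float32 features and 8-byte int64 labels (the 1.1GB figure matches float32; the 30KB label file is presumably about 28.6KB of int64 plus serialization overhead). The helper name `tensor_nbytes` is only for illustration. The same arithmetic shows how the feature tensor grows for 10x and 100x more data:

```python
from math import prod


def tensor_nbytes(shape, bytes_per_element=4):
    """Bytes for a dense tensor; 4 bytes/element = float32, 8 = int64."""
    return prod(shape) * bytes_per_element


x_bytes = tensor_nbytes((3577, 200, 384))
y_bytes = tensor_nbytes((3577,), bytes_per_element=8)
print(x_bytes, y_bytes)  # 1098854400 28616

for scale in (1, 10, 100):
    gb = tensor_nbytes((3577 * scale, 200, 384)) / 1e9
    print(f"{scale}x data: {gb:.1f} GB of float32 features")
```

At 10x the features alone take about 11 GB, leaving little headroom in 16GB of RAM; at 100x (about 110 GB) they cannot be held in memory at all.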
The simplest approach loads everything into memory and trains on the whole dataset as a single batch:
```python
import logging

import torch

# ClassificationDecoder is my own model class (definition not shown here).

logging.basicConfig(level=logging.INFO)

num_classes = 10
decoder = ClassificationDecoder(
    hidden_dim=384,
    n_head=2,
    n_layer=2,
    num_classes=num_classes,
    fingerprint_mode=True,
)

logging.info("Load data")
x_tensor = torch.load("x_tensor.pt")  # shape [3577, 200, 384]
y_tensor = torch.load("y_tensor.pt")  # shape [3577]

optimiser = torch.optim.Adam(decoder.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")

epochs = 5
for j in range(epochs):
    # Full batch: the whole dataset goes through the model in one forward pass.
    x_audio = x_tensor
    expected_output = y_tensor
    x_text = torch.ones(len(x_audio), 1, 384)
    output = decoder(x_text, x_audio)
    loss = loss_fn(output, expected_output)
    logging.info(f"Epoch {j}: {loss:.2f}")
    optimiser.zero_grad()
    logging.info("do loss.backward")
    loss.backward(retain_graph=False)
    logging.info("do optimiser step")
    optimiser.step()
    logging.info("Done")
```
This works and takes at most 15 seconds per epoch.
Then I tried to upgrade it to use `TensorDataset` and `DataLoader`. The new code reads:
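```python
import logging

import torch
from torch.utils.data import DataLoader, TensorDataset

# ClassificationDecoder is my own model class (definition not shown here).

logging.basicConfig(level=logging.INFO)

num_classes = 10
decoder = ClassificationDecoder(
    hidden_dim=384,
    n_head=2,
    n_layer=2,
    num_classes=num_classes,
    fingerprint_mode=True,
)

logging.info("Load data")
x_tensor = torch.load("x_tensor.pt")
y_tensor = torch.load("y_tensor.pt")
dataset = TensorDataset(x_tensor, y_tensor)
data_loader = DataLoader(dataset, batch_size=500)

optimiser = torch.optim.Adam(decoder.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")

epochs = 5
for j in range(epochs):
    optimiser.zero_grad()
    for data in data_loader:
        x_audio = data[0]
        expected_output = data[1]
        x_text = torch.ones(len(x_audio), 1, 384)
        output = decoder(x_text, x_audio)
        loss = loss_fn(output, expected_output)
        # Gradients accumulate in .grad across batches (no zero_grad here).
        loss.backward(retain_graph=False)
    # One optimiser step per epoch, as in the full-batch version.
    # Note: `loss` here is only the last batch's loss.
    logging.info(f"Epoch {j}: {loss:.2f}")
    optimiser.step()
```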
This also works, but it is significantly slower, even though everything else stays the same. I've experimented with the `batch_size` parameter, but one epoch still takes at least 20 minutes (at least an 80x slowdown!) and most of the time is spent in the `loss.backward()` step.
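One way to separate "is it computing the same thing?" from "where does the time go?" is a small experiment like the sketch below. It assumes, as the line order suggests, that `zero_grad()` and `step()` run once per epoch; with `reduction="sum"` the full-batch loss is then the sum of the per-batch losses, and `backward()` adds each batch's gradient into `.grad`, so both versions should produce the same gradient. The model and data here are hypothetical small stand-ins (a `Flatten` + `Linear` model on random tensors), not the original decoder, so the timings only illustrate the method. It also prints `torch.get_num_threads()`, since CPU forward and backward speed depends on the intra-op thread count.

```python
import time

import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# Hypothetical small stand-ins for x_tensor / y_tensor and the decoder.
x = torch.randn(2000, 20, 32)
y = torch.randint(0, 10, (2000,))
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(20 * 32, 10))
loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")


def full_batch_grad():
    """Gradient of the summed loss over the whole dataset in one pass."""
    model.zero_grad()
    loss_fn(model(x), y).backward()
    return model[1].weight.grad.clone()


def accumulated_grad(batch_size):
    """Same gradient, built by summing per-batch backward() calls into .grad.

    Also returns wall-clock seconds spent loading, in forward and in backward.
    """
    model.zero_grad()
    times = {"load": 0.0, "forward": 0.0, "backward": 0.0}
    loader = DataLoader(TensorDataset(x, y), batch_size=batch_size)
    t0 = time.perf_counter()
    for xb, yb in loader:
        t1 = time.perf_counter()
        loss = loss_fn(model(xb), yb)
        t2 = time.perf_counter()
        loss.backward()  # adds this batch's gradient to .grad
        t3 = time.perf_counter()
        times["load"] += t1 - t0
        times["forward"] += t2 - t1
        times["backward"] += t3 - t2
        t0 = time.perf_counter()
    return model[1].weight.grad.clone(), times


print("intra-op threads:", torch.get_num_threads())
reference = full_batch_grad()
for bs in (2000, 500, 100):
    grad, times = accumulated_grad(bs)
    same = torch.allclose(reference, grad, atol=1e-3)
    print(bs, same, {k: round(v, 4) for k, v in times.items()})
```

Running the same kind of per-phase timing (and the thread-count print) inside both of the real scripts is a way to check whether the extra time really sits in `backward()` and whether the two runs use the same CPU settings.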
If anyone could help me out with the following questions, I’d be really grateful:
- Why is the second solution so much slower? Is there any way to improve the code to make it faster? Is the slowdown inherent to datasets/dataloaders, or am I using them wrong?
- In the future I will want to train on 10x or even 100x more data, which will certainly not fit in my memory (16GB RAM). What is the correct way of dealing with this? Similarly, if at some point I want to train on GPUs, what are good practices for loading the data and moving it to GPU memory? Any suggestions or references would be welcome.
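A common pattern for data larger than RAM is to keep the samples in a memory-mapped array on disk and let a custom `Dataset` read one sample per `__getitem__`, so only the current batches are resident in memory. For GPU training, the usual companion is `DataLoader(..., pin_memory=True)` plus `.to(device, non_blocking=True)`. The sketch below assumes the features are stored as `.npy` files; the class `MemmapDataset`, the function `train_epoch`, the file names and the small `Flatten` + `Linear` model are all hypothetical. Unlike the code above, it takes one optimiser step per batch (ordinary mini-batch training). For existing `.pt` files, recent PyTorch releases (2.1 and later) also accept `torch.load(path, mmap=True)`, which maps the saved tensor instead of reading it all into RAM.

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class MemmapDataset(Dataset):
    """Hypothetical dataset that reads one sample at a time from .npy files."""

    def __init__(self, x_path, y_path):
        self.x_path = x_path
        self.x = None  # opened lazily, so each DataLoader worker opens its own map
        self.y = np.load(y_path)  # labels are small, so keep them in RAM

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        if self.x is None:
            # mmap_mode="r" maps the file; pages are read from disk when touched.
            self.x = np.load(self.x_path, mmap_mode="r")
        # np.array(...) copies just this sample out of the map into RAM.
        sample = torch.from_numpy(np.array(self.x[idx], dtype=np.float32))
        return sample, int(self.y[idx])


def train_epoch(model, loader, loss_fn, optimiser, device):
    """One epoch of ordinary mini-batch training; returns the number of batches."""
    model.train()
    n_batches = 0
    for xb, yb in loader:
        # non_blocking=True can only overlap the copy when the batch is pinned.
        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)
        optimiser.zero_grad()
        loss = loss_fn(model(xb), yb)
        loss.backward()
        optimiser.step()
        n_batches += 1
    return n_batches


# Demo with a small hypothetical dataset written to a temporary directory.
tmp = tempfile.mkdtemp()
x_path, y_path = os.path.join(tmp, "x.npy"), os.path.join(tmp, "y.npy")
# open_memmap fills the .npy file on disk incrementally, one sample at a time.
x_out = np.lib.format.open_memmap(x_path, mode="w+", dtype=np.float32, shape=(64, 20, 32))
for i in range(64):
    x_out[i] = np.random.rand(20, 32)
x_out.flush()
del x_out
np.save(y_path, np.random.randint(0, 10, size=64))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = MemmapDataset(x_path, y_path)
loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0,  # raise to read ahead in background worker processes
    pin_memory=(device.type == "cuda"),  # page-locked host memory for faster copies
)
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(20 * 32, 10)).to(device)
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
print("batches:", train_epoch(model, loader, loss_fn, optimiser, device))
```

Raising `num_workers` lets several processes read from disk while the model trains; on Windows and macOS that requires the training code to sit under an `if __name__ == "__main__":` guard.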