I have had a problem for a few days that is driving me crazy. I have a seq2seq model (specifically a Listen, Attend and Spell model from Google) that I'm training, and its GPU memory usage always blows up when it reaches minibatch number 10. I have tried shuffling the dataloader to make sure the problem wasn't related to the specific data being retrieved on the 10th iteration, but the error still occurs on that 10th iteration.
My train loop looks as follows:
for epoch in range(start_epoch, epochs):
    epoch_step = 0
    train_loss = []
    train_ler = []
    batch_loss = 0
    for i, (data) in enumerate(train_loader):
        print(
            f"Current Epoch: {epoch} Loss {np.round(batch_loss, 3)} | Epoch step: {epoch_step}/{len(train_loader)}",
            end="\r",
            flush=True,
        )
        # Adjust LR
        tf_rate = tf_rate_upperbound - (tf_rate_upperbound - tf_rate_lowerbound) * min(
            (float(global_step) / tf_decay_step), 1
        )
        with torch.no_grad():
            inputs = data[1]["inputs"].cuda()
            labels = data[2]["targets"].cuda()
        # minibatch execution
        batch_loss, batch_ler = batch_iterator(
            batch_data=inputs,
            batch_label=labels,
            las_model=las,
            optimizer=optimizer,
            tf_rate=tf_rate,
            is_training=True,
            max_label_len=params["data"]["vocab_size"],
            label_smoothing=params["training"]["label_smoothing"],
        )
        del inputs
        del labels
        torch.cuda.empty_cache()
        train_loss.append(batch_loss)
        train_ler.extend(batch_ler)
        global_step += 1
        epoch_step += 1
Here batch_iterator trains the network on one minibatch. batch_iterator looks as follows:
def batch_iterator(
    batch_data,
    batch_label,
    las_model,
    optimizer,
    tf_rate,
    is_training,
    max_label_len,
    label_smoothing,
    use_gpu=True,
):
    label_smoothing = label_smoothing
    max_label_len = min([batch_label.size()[1], max_label_len])
    criterion = nn.NLLLoss(ignore_index=0).cuda()
    optimizer.zero_grad()
    raw_pred_seq, _ = las_model(
        batch_data=batch_data,
        batch_label=batch_label,
        teacher_force_rate=tf_rate,
        is_training=is_training,
    )
    pred_y = (
        torch.cat([torch.unsqueeze(each_y, 1) for each_y in raw_pred_seq], 1)[:, :max_label_len, :]
    ).contiguous()
    if label_smoothing == 0.0 or not (is_training):
        pred_y = pred_y.permute(0, 2, 1)  # pred_y.contiguous().view(-1,output_class_dim)
        true_y = torch.max(batch_label, dim=2)[1][:, :max_label_len].contiguous()  # .view(-1)
        loss = criterion(pred_y, true_y)
        # variable -> numpy before sending into LER calculator
        batch_ler = LetterErrorRate(
            torch.max(pred_y.permute(0, 2, 1), dim=2)[1]
            .cpu()
            .numpy(),  # .reshape(current_batch_size,max_label_len),
            true_y.cpu().data.numpy(),
        )  # .reshape(current_batch_size,max_label_len), data)
    else:
        true_y = batch_label[:, :max_label_len, :].contiguous()
        true_y = true_y.type(torch.cuda.FloatTensor) if use_gpu else true_y.type(torch.FloatTensor)
        loss = label_smoothing_loss(pred_y, true_y, label_smoothing=label_smoothing)
        batch_ler = LetterErrorRate(
            torch.max(pred_y, dim=2)[1].cpu().numpy(),  # .reshape(current_batch_size,max_label_len),
            torch.max(true_y, dim=2)[1].cpu().data.numpy(),
        )  # .reshape(current_batch_size,max_label_len), data)
    if is_training:
        loss.backward()
        optimizer.step()
    batch_loss = loss.cpu().data.numpy()
    return batch_loss, batch_ler
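For readers skimming the training loop, the tf_rate expression is a linear decay of the teacher forcing rate that is clamped once global_step passes tf_decay_step. Below is a minimal standalone sketch of that formula. The function name and the numeric settings are hypothetical and chosen only for illustration; they are not taken from the repository's config.

```python
def teacher_force_rate(global_step, upperbound, lowerbound, decay_step):
    """Linearly decay from upperbound to lowerbound over decay_step steps, then stay flat."""
    # Fraction of the decay completed so far, capped at 1 once decay_step is reached.
    progress = min(float(global_step) / decay_step, 1)
    return upperbound - (upperbound - lowerbound) * progress


# Hypothetical settings for illustration only.
for step in (0, 50, 100, 200):
    rate = teacher_force_rate(step, upperbound=1.0, lowerbound=0.5, decay_step=100)
    print(f"step {step}: tf_rate = {rate}")
```

With these illustrative settings the rate starts at the upper bound, reaches the lower bound at step 100, and stays there afterwards, so the schedule itself does not change anything specifically at iteration 10.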
The full code of the project is in this GitHub repository, if you want to check it: https://github.com/jiwidi/las-pytorch
Running on a V100 (16 GB of GPU RAM), memory usage is constant at around 8 GB and spikes to 14 GB on the 10th iteration. Since the error always occurs on this exact iteration even when I shuffle the dataloader, I believe the error is in my code.
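One way to narrow down where a jump like this comes from is to log PyTorch's own allocator statistics after every minibatch. The following is a minimal sketch, not part of the repository: the helper name cuda_memory_mib is hypothetical, and it assumes a PyTorch version that provides torch.cuda.memory_reserved (older releases exposed the same figure as torch.cuda.memory_cached).

```python
import torch

MIB = 1024 ** 2  # bytes per MiB


def cuda_memory_mib(device=0):
    """Return PyTorch allocator statistics in MiB, or None when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return None
    return {
        # Memory currently occupied by live tensors.
        "allocated": torch.cuda.memory_allocated(device) / MIB,
        # Memory held by the caching allocator (live tensors plus cached blocks).
        "reserved": torch.cuda.memory_reserved(device) / MIB,
        # Highest value "allocated" has reached so far in this process.
        "peak_allocated": torch.cuda.max_memory_allocated(device) / MIB,
    }


stats = cuda_memory_mib()
print("CUDA not available" if stats is None else stats)
```

Calling the helper right after each batch_iterator call and printing the result next to the step index should show whether "allocated" climbs steadily across iterations, which would suggest tensors being kept alive between steps, or jumps only inside a single forward pass.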
I also tried putting print statements before each batch_iterator call to print the sizes of both inputs and labels; they were about the same as those in the previous minibatches. Output here:
For minibatch 0 inputs has size 2.048mb and labels has size 0.706944mb
For minibatch 1 inputs has size 2.08896mb and labels has size 0.706944mb
For minibatch 2 inputs has size 2.00704mb and labels has size 0.706944mb
For minibatch 3 inputs has size 2.12992mb and labels has size 0.706944mb
For minibatch 4 inputs has size 2.17088mb and labels has size 0.706944mb
For minibatch 5 inputs has size 2.048mb and labels has size 0.706944mb
For minibatch 6 inputs has size 2.048mb and labels has size 0.706944mb
For minibatch 7 inputs has size 2.048mb and labels has size 0.706944mb
For minibatch 8 inputs has size 1.92512mb and labels has size 0.706944mb
For minibatch 9 inputs has size 2.048mb and labels has size 0.706944mb
For minibatch 10 inputs has size 1.96608mb and labels has size 0.706944mb
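For reference, sizes like the ones above can be computed from a tensor's element count and per-element byte width. This is a minimal sketch with a hypothetical helper name, tensor_size_mb; it assumes "mb" means 10**6 bytes, which appears consistent with the six-decimal figures in the output but is not confirmed by the original code.

```python
def tensor_size_mb(tensor):
    """Return the storage used by the tensor's elements, in MB (10**6 bytes)."""
    # element_size() is the number of bytes per element, nelement() the element count;
    # both are standard torch.Tensor methods.
    return tensor.element_size() * tensor.nelement() / 1e6


# Usage inside the training loop (assumes `i`, `inputs` and `labels` as in the loop above):
# print(f"For minibatch {i} inputs has size {tensor_size_mb(inputs)}mb "
#       f"and labels has size {tensor_size_mb(labels)}mb")
```

Note that this only measures the input and label tensors themselves. The activations saved for the backward pass are not counted, and for a sequence model they usually scale with the input and output lengths, so similar input sizes do not by themselves rule out a much larger memory footprint in one step.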
The actual error PyTorch gives me:
Traceback (most recent call last):
  File "train.py", line 138, in <module>
    label_smoothing=params["training"]["label_smoothing"],
  File "/home/fhjaime966/las-pytorch/solver/solver.py", line 63, in batch_iterator
    is_training=is_training,
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/fhjaime966/las-pytorch/model/las_model.py", line 34, in forward
    listener_feature, ground_truth=batch_label, teacher_force_rate=teacher_force_rate
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/fhjaime966/las-pytorch/model/las_model.py", line 212, in forward
    rnn_input, hidden_state, listener_feature
  File "/home/fhjaime966/las-pytorch/model/las_model.py", line 180, in forward_step
    attention_score, context = self.attention(rnn_output, listener_feature)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/fhjaime966/las-pytorch/model/las_model.py", line 296, in forward
    * attention_score[0].unsqueeze(2).repeat(1, 1, listener_feature.size(2)),
RuntimeError: CUDA out of memory. Tried to allocate 12.00 MiB (GPU 0; 15.75 GiB total capacity; 14.69 GiB already allocated; 6.88 MiB free; 14.69 GiB reserved in total by PyTorch)
I know this is quite a big problem to debug, since you would need to check the code in the repository https://github.com/jiwidi/las-pytorch to understand what is running, so any help or tips on debugging GPU memory usage in PyTorch are greatly appreciated.