I have been struggling with a memory leak for a while and I cannot find what I am doing wrong. The code I am using is OpenNMT-py, and the main change I am trying to make is a modified training loop that includes a sampling step, so I can apply the policy gradient theorem to NMT.
The problem is the following. If I include the policy gradient term in the loss and call .backward() on it, I get a GPU out-of-memory error after several batches:
Traceback (most recent call last):
File "/home/mpiccard/Data/Desktop/HAN_NMT/full_source/train.py", line 592, in <module>
main()
File "/home/mpiccard/Data/Desktop/HAN_NMT/full_source/train.py", line 584, in main
train_model(model, fields, optim, data_type, model_opt, opt.train_part, opt.batch_size)
File "/home/mpiccard/Data/Desktop/HAN_NMT/full_source/train.py", line 337, in train_model
train_stats = trainer.REINFORCE_train(train_iter, epoch, report_func, train_part, data=dataset)
File "/data/mpiccard/Desktop/HAN_NMT/full_source/onmt/Trainer.py", line 293, in REINFORCE_train
report_stats, normalization, train_part,data=data,thisBatch=batch,REINFORCE=True)
File "/data/mpiccard/Desktop/HAN_NMT/full_source/onmt/Trainer.py", line 437, in _gradient_accumulation
REINFORCE=REINFORCE, data=data, batch=batch)
File "/home/mpiccard/Data/Desktop/pyTorch_last/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
result = self.forward(*input, **kwargs)
File "/data/mpiccard/Desktop/HAN_NMT/full_source/onmt/Models.py", line 676, in forward
ret = self.REINFORCE_step(batch,data,context_index,src,lengths,enc_final,memory_bank,part)
File "/data/mpiccard/Desktop/HAN_NMT/full_source/onmt/Models.py", line 773, in REINFORCE_step
dec_out, _,_ = self.doc_context[1](cache[0], dec_out, cache[1], context, batch_i=batch_i)
File "/home/mpiccard/Data/Desktop/pyTorch_last/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
result = self.forward(*input, **kwargs)
File "/data/mpiccard/Desktop/HAN_NMT/full_source/onmt/modules/HierarchicalContext.py", line 186, in forward
query_word_norm, mask=context_word_mask, all_attn=True)
File "/home/mpiccard/Data/Desktop/pyTorch_last/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
result = self.forward(*input, **kwargs)
File "/data/mpiccard/Desktop/HAN_NMT/full_source/onmt/modules/MultiHeadedAttn.py", line 121, in forward
key_up = shape(self.linear_keys(key))
File "/home/mpiccard/Data/Desktop/pyTorch_last/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
result = self.forward(*input, **kwargs)
File "/home/mpiccard/Data/Desktop/pyTorch_last/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 55, in forward
return F.linear(input, self.weight, self.bias)
File "/home/mpiccard/Data/Desktop/pyTorch_last/lib/python3.6/site-packages/torch/nn/functional.py", line 1026, in linear
output = input.matmul(weight.t())
RuntimeError: CUDA error: out of memory
Process finished with exit code 1
However, if I don't include the policy gradient output in the loss (so I don't backpropagate through those tensors), I don't get any memory error and everything works fine.
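For context, the extra term I add looks roughly like this. It is a simplified, self-contained sketch with dummy tensors; the variable names, shapes, and the mixing weight alpha are placeholders, not the actual OpenNMT-py code:

import torch

# sample_log_probs: log-probabilities of the sampled tokens, shape (tgt_len, batch)
# rewards: per-sentence reward (e.g. sentence-level BLEU), shape (batch,)
sample_log_probs = torch.randn(20, 8, requires_grad=True)
rewards = torch.rand(8)
nll_loss = torch.randn(20, 8, requires_grad=True).sum()
alpha = 0.5  # placeholder mixing weight

# REINFORCE / policy gradient term: -E[log p(sample) * reward]
rl_loss = -(sample_log_probs.sum(0) * rewards.detach()).mean()
total_loss = nll_loss + alpha * rl_loss
total_loss.backward()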
I have monitored the GPU memory consumption in different ways. When I use torch.cuda.memory_allocated(), I always obtain the same value, 1.7 GB (the GPU has 16 GB in total). But when I check the memory usage with nvidia-smi, the value is already much higher from the start (~8 GB) and keeps growing on every iteration until I run out of memory.
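This is roughly how I read the allocator statistics on each iteration (in the PyTorch version I am using, the cached pool is reported by torch.cuda.memory_cached(); newer releases renamed it to memory_reserved()):

import torch

step = 0  # placeholder for the current iteration counter
allocated = torch.cuda.memory_allocated() / 1024 ** 3
cached = torch.cuda.memory_cached() / 1024 ** 3  # memory_reserved() in newer PyTorch
print('step %d: allocated %.2f GB, cached %.2f GB' % (step, allocated, cached))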
I have tried different ways of releasing memory:
loss.detach()
del loss
backward(retain_graph=False)
gc.collect()
torch.cuda.empty_cache()
But none of them have worked.
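For clarity, this is roughly where I call them relative to the backward pass (a minimal stand-alone sketch with a dummy loss; in the real code the loss comes from the model forward):

import gc
import torch

# dummy stand-in for the combined NLL + policy gradient loss
loss = torch.randn(8, device='cuda', requires_grad=True).sum()

loss.backward(retain_graph=False)
loss = loss.detach()
del loss
gc.collect()
torch.cuda.empty_cache()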
I hope someone can shed some light on this problem. In particular, I don't understand the discrepancy between torch.cuda.memory_allocated() and nvidia-smi. Why am I holding on to more memory every iteration, and why am I not able to free that space?
By the way, when I run on the CPU I don't face this problem. However, it may just be that there is more RAM there and I have not yet reached the critical point (as it runs very slowly).