I have a model with two branches. Both branches share the embedding layer, but then they diverge and each loss is calculated against different labels. At the end, I sum the two losses and call backward only on the accumulated loss. However, without retain_graph=True it throws an error asking for retain_graph to be set to True, and when I do retain the graph, the backward step is so slow that it makes training practically impossible. Any solution for this?
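A minimal sketch of the setup as I understand it (all module names, sizes, and the pooling step are made up for illustration); summing the two losses and calling backward once should not require retain_graph:

```python
import torch
import torch.nn as nn

# Hypothetical two-branch model: shared embedding, two separate heads.
class TwoBranchModel(nn.Module):
    def __init__(self, vocab_size=1000, emb_dim=64, n_classes_a=5, n_classes_b=3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.head_a = nn.Linear(emb_dim, n_classes_a)
        self.head_b = nn.Linear(emb_dim, n_classes_b)

    def forward(self, x):
        emb = self.embedding(x).mean(dim=1)  # pool over the sequence dimension
        return self.head_a(emb), self.head_b(emb)

model = TwoBranchModel()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randint(0, 1000, (8, 10))      # dummy token batch
labels_a = torch.randint(0, 5, (8,))     # labels for branch A
labels_b = torch.randint(0, 3, (8,))     # labels for branch B

out_a, out_b = model(x)
loss = criterion(out_a, labels_a) + criterion(out_b, labels_b)

optimizer.zero_grad()
loss.backward()  # one backward on the summed loss; no retain_graph needed
optimizer.step()
```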
Quite likely you are doing something wrong elsewhere and storing a non-leaf tensor (e.g. the result of calling .to() on a tensor before storing it).
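To illustrate one common way this goes wrong (this is an assumption about the cause, not taken from the thread): if a stored non-leaf tensor, such as an accumulated loss, stays attached to the graph across iterations, the next backward tries to traverse the previous iteration's already-freed graph, which is exactly the error asking for retain_graph=True:

```python
import torch

w = torch.randn(3, requires_grad=True)
total = torch.tensor(0.0)

for step in range(2):
    loss = (w * w).sum()

    # Bug: `total` is a non-leaf that stays attached to every previous
    # iteration's graph. The second backward() re-traverses the graph
    # that the first backward() already freed and raises the
    # "retain_graph=True" error.
    total = total + loss
    total.backward()  # fails on the second iteration

    # Fix: detach before accumulating, e.g.
    # total = total.detach() + loss
```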
As a rule of thumb, retain_graph=True is only for cases where you know exactly why you need it, not for when some error message tells you to set it (maybe we should improve that message).
Thanks so much for this answer! I kept focusing on retain_graph, but in fact I had a deeper issue in the way I was handling the input and output of my LSTM.
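The thread doesn't show the actual fix, but a common LSTM mistake that produces exactly this error is carrying the hidden state across batches without detaching it. A minimal sketch of that fix (the module sizes and dummy loss are made up):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
optimizer = torch.optim.SGD(lstm.parameters(), lr=0.1)
hidden = None

for step in range(3):
    x = torch.randn(4, 5, 8)  # (batch, seq, features)
    out, hidden = lstm(x, hidden)
    loss = out.sum()  # dummy loss for illustration

    # Detach the carried-over (h, c) state so the next backward does not
    # try to traverse the previous batch's already-freed graph.
    hidden = tuple(h.detach() for h in hidden)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```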