Efficiently build decoder output tensor

I’m working on a sequence-to-sequence model with attention. The decoder builds its output tensor iteratively, one timestep at a time, inside a loop. I’ve tried several ways to make this loop more efficient, but each has a drawback, so I wanted to ask the community whether there’s a better approach.

Method 1 often causes an OOM error. I think this is because each appended output keeps its autograd history alive, so the graph for all max_len timesteps is retained at once.
Method 2 is by far the fastest, but the graph is lost, so the model obviously fails to train.
Method 3 works but is very slow.

Is there some way to “temporarily” detach from the graph, so that the outputs can be stored in the list outside the loop, and then reattach them without losing the autograd history? Or is there some other, more efficient method that I’m missing?
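
For reference, this is the kind of “detach and reattach” I’ve tried; as far as I can tell, requires_grad_() only turns the tensor into a fresh leaf rather than restoring the history (minimal standalone sketch):

import torch

a = torch.randn(3, requires_grad=True)
b = (a * 2).detach()   # stored "outside the graph": b.grad_fn is None
b.requires_grad_()     # my attempted reattach: b becomes a new leaf
print(b.grad_fn)       # still None, the link back to a is gone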

Method 1 (store predictions in a list then stack)

preds_list = []
probs_list = []
for timestep in range(max_len):
    decoder_out, decoder_h, attn_weights = self.decoder(dec_x, h, src)
    top_value, top_index = torch.max(decoder_out, dim=1, keepdim=True)
    preds_list.append(top_index.detach())
    probs_list.append(decoder_out)    # keeps this step's autograd history alive
    dec_x, h = top_index, decoder_h   # feed prediction and hidden state forward
preds = torch.stack(preds_list, dim=1)
probs = torch.stack(probs_list, dim=1)

return preds, probs
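
Regarding the OOM in Method 1, this is a standalone sketch of how I convinced myself that it’s the retained autograd history, not copies of the graph, that grows with the list (assumes a CUDA device for the memory readout):

import torch

device = "cuda"   # assumes a CUDA device is available
layer = torch.nn.Linear(1024, 1024).to(device)
x = torch.randn(64, 1024, device=device)

outs = []
for t in range(200):
    x = torch.tanh(layer(x))
    outs.append(x)   # pins the activations autograd saved for step t
    if t % 50 == 0:
        print(t, torch.cuda.memory_allocated() // 2**20, "MiB")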

Method 2 (store detached outputs in a list then stack)

preds_list = []
probs_list = []
for timestep in range(max_len):
    decoder_out, decoder_h, attn_weights = self.decoder(dec_x, h, src)
    top_value, top_index = torch.max(decoder_out, dim=1, keepdim=True)
    preds_list.append(top_index.detach())
    probs_list.append(decoder_out.detach())   # severs this step's autograd history
    dec_x, h = top_index, decoder_h           # feed prediction and hidden state forward
preds = torch.stack(preds_list, dim=1)
probs = torch.stack(probs_list, dim=1).requires_grad_()   # new leaf; history is not restored

return preds, probs
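
To show what I mean by Method 2 failing to train: backward() runs without error on the reattached tensor, but no gradient ever reaches the decoder’s weights (standalone sketch, with a Linear standing in for self.decoder):

import torch

decoder = torch.nn.Linear(8, 8)          # stand-in for self.decoder
out = decoder(torch.randn(2, 8))
probs = out.detach().requires_grad_()    # what Method 2 produces

probs.sum().backward()                   # runs without error...
print(decoder.weight.grad)               # ...but prints None, so nothing learns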

Method 3 (pre-allocate tensor and set timestep indices)

preds_list = []
probs = torch.zeros(batch_size, max_len, self.output_size, device=src.device)  # allocate on the input's device
for timestep in range(max_len):
    decoder_out, decoder_h, attn_weights = self.decoder(dec_x, h, src)
    top_value, top_index = torch.max(decoder_out, dim=1, keepdim=True)
    preds_list.append(top_index.detach())
    probs[:, timestep] = decoder_out    # in-place write, still tracked by autograd
    dec_x, h = top_index, decoder_h     # feed prediction and hidden state forward
preds = torch.stack(preds_list, dim=1)

return preds, probs
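
In case the relative speeds matter for an answer, this is roughly the harness I used to time the three methods (bench and the model.decode call are illustrative placeholders, not my real code):

import time
import torch

def bench(fn, iters=20):
    # crude wall-clock timing; synchronize so CUDA kernels are counted
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

# e.g. print(bench(lambda: model.decode(src_batch))) for each method variant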