Should we .detach() predicted model outputs used as input in seq2seq model training?

When training an RNN model for seq2seq tasks, we can choose, for the current batch or timestep, to feed either the predicted model output or the ground-truth target as the next input. Should we .detach() the predicted model output in the former case?

The source of my confusion is an apparent discrepancy between two PyTorch tutorials that use a similar encoder-decoder architecture for seq2seq tasks:
https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html

decoder_input = topi.squeeze().detach()  # detach from history as input

https://pytorch.org/tutorials/beginner/chatbot_tutorial.html

decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]])

A related concept is “Scheduled Sampling” (https://arxiv.org/pdf/1506.03099.pdf), and in that approach the model-predicted outputs are also not backpropagated through.
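For concreteness, here is a minimal sketch of that per-timestep choice with a toy GRU decoder (all sizes, the out_proj layer, and the random targets are made up for illustration and are not taken from either tutorial):

import random
import torch
import torch.nn as nn

vocab_size, hidden_size, target_len = 20, 16, 5
embedding = nn.Embedding(vocab_size, hidden_size)
gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
out_proj = nn.Linear(hidden_size, vocab_size)
criterion = nn.NLLLoss()

targets = torch.randint(0, vocab_size, (1, target_len))  # fake ground-truth token ids
decoder_input = torch.zeros(1, 1, dtype=torch.long)      # "SOS" token with id 0
hidden = torch.zeros(1, 1, hidden_size)                  # initial hidden state
teacher_forcing_ratio = 0.5
loss = 0.0

for t in range(target_len):
    output, hidden = gru(embedding(decoder_input), hidden)
    log_probs = torch.log_softmax(out_proj(output.squeeze(1)), dim=-1)
    loss = loss + criterion(log_probs, targets[:, t])
    if random.random() < teacher_forcing_ratio:
        decoder_input = targets[:, t].unsqueeze(1)   # teacher forcing: ground truth has no graph
    else:
        _, topi = log_probs.topk(1)
        decoder_input = topi.detach()                # own prediction, detached as in the tutorial

loss.backward()  # gradients still flow through the hidden state at every timestep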

In the Stack Overflow discussion “Should Decoder Prediction Be Detached in PyTorch Training?” it is mentioned that we do want to .detach() because of the extra memory used by keeping those variables in the computation graph. I am wondering, however, how this relates to the overall performance of the model in terms of accuracy rather than efficiency.

I would like to know why we would want to .detach() the model-predicted outputs from the computation graph, and in what contexts we might not want to do that. If we are using an RNN model, doesn’t this mean that we are breaking the BPTT that RNNs rely on to learn effectively? I am dealing with a similar situation where I need to iterate through each timestep of variable-length batched sequences during training, but I am unsure whether I should .detach() the model-predicted outputs that are used as inputs in those training iterations. My intuition is that in some cases we would indeed want to backpropagate through model-predicted outputs during training.

Both approaches will create a detached tensor, since re-wrapping the values into a new tensor also does not keep the gradient history alive.
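As a quick check with a toy tensor standing in for a differentiable model output (not taken from either tutorial):

import torch

y = torch.randn(2, 3, requires_grad=True) * 2.0    # stands in for an output with grad history
a = y.detach()                                      # explicit .detach()
b = torch.tensor([[y[i][j].item() for j in range(3)] for i in range(2)])  # re-wrapping the values
print(y.grad_fn is not None, a.grad_fn, b.grad_fn)  # True None None -> both are cut from the graph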

If you are not detaching the input, you will keep the gradient history alive, which will allow you to calculate the gradients w.r.t. all used parameters of all iterations, but it will also increase the memory usage (since the intermediate tensors of all iterations are kept alive) and can also lead to errors.
Errors are raised if you have already updated the used parameters via optimizer.step() while the computation graph is still attached to all previous iterations.
Here is a small example:

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10)
)
optimizer = torch.optim.SGD(model.parameters(), lr=1.)

# works since we are creating one large computation graph spanning all 10 iterations
x = torch.randn(1, 10)
for _ in range(10):
    x = model(x)

x.mean().backward()
optimizer.step()

# breaks
optimizer.zero_grad()
x = torch.randn(1, 10)
for i in range(10):
    print(i)
    x = model(x)
    x.mean().backward(retain_graph=True)  # backprops through all previous iterations as well
    optimizer.step()  # updates the parameters while older graphs still reference them

# output
# 0
# 1
# RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [10, 10]], which is output 0 of AsStridedBackward0, is at version 4; expected version 3 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

This error is raised because the intermediate forward activations from the earlier iterations are “stale”: they were not created by the current parameters of the model, which have since been changed by optimizer.step().
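A common workaround, shown here only as a sketch that reuses model and optimizer from the snippet above, is to detach the carried tensor between iterations so that each backward() only sees activations created by the current parameters:

# works again: each iteration owns a fresh, one-step computation graph
optimizer.zero_grad()
x = torch.randn(1, 10)
for i in range(10):
    x = model(x.detach())   # cut the history from previous iterations
    x.mean().backward()     # the graph only spans this iteration
    optimizer.step()
    optimizer.zero_grad()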


The predicted model outputs are essentially used as inputs to the next timesteps. Ideally, input tensors do not have a computation graph attached to them, so they are generally detached.

With teacher forcing, the ground truths do not come with a graph history attached to them, so this is also in sync with the ideal case.

This also explains why BPTT isn’t broken by detaching these inputs: gradients still flow back through the hidden state, which stays attached across timesteps.
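To make that concrete, here is a toy sketch (made-up sizes, a single nn.GRUCell standing in for the decoder) showing that gradients still reach the parameters used at every timestep through the hidden state, even though the fed-back token is detached:

import torch
import torch.nn as nn

cell = nn.GRUCell(8, 8)
embedding = nn.Embedding(10, 8)
proj = nn.Linear(8, 10)

hidden = torch.zeros(1, 8)
token = torch.zeros(1, dtype=torch.long)
for _ in range(4):
    hidden = cell(embedding(token), hidden)   # the hidden state keeps its graph across timesteps
    logits = proj(hidden)
    token = logits.argmax(dim=-1).detach()    # the fed-back token is detached

logits.sum().backward()
print(cell.weight_ih.grad.abs().sum() > 0)    # tensor(True): gradient accumulated from all steps via the hidden-state chain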

As @ptrblck has explained, it is possible to experiment with a case where the predicted outputs are used as next inputs without detaching.
I am not sure what performance difference it will bring in terms of the accuracy for any kind of auto-regressive model, but generally detaching is done to mimic the ideal case.

It would nevertheless be interesting to see if un-detached inputs have been experimentally explored with auto-regressive models.


Thanks @ptrblck and @srishti-git1110 for the clarification! It makes sense that we should .detach() those variables, since we usually do not want to associate gradients with inputs. I seem to have confused the output of the GRU/LSTM with the hidden state, which is in fact backpropagated through (and in the case of per-timestep training is actually the same as the output of a GRU/LSTM at each timestep).

I would like to ask a few other questions if you don’t mind:

  1. Is BPTT executed on an RNN model when we call loss.backward(), where loss is, say, the accumulated loss across each timestep of a training sequence, after which we call optim.step(), which updates the model parameters using the gradients calculated from all prior outputs (hidden states) of the sequence?

  2. What if we call loss.backward(retain_graph=True) and optim.step() at each timestep, but .detach() the hidden state (in order to avoid the error you show)? There is no BPTT happening in this case, since we break the computation graph at each timestep. Is there then a way to update the model parameters at each timestep of a sequence without needing to .detach() the hidden state? It seems this issue was raised before but was left unanswered. (I sketch both of these setups after question 3.)

  3. When training an LSTM/GRU, there are normally two inputs provided to the model at each timestep: the current input features and the previous hidden state. Training seq2seq models often uses “teacher forcing”, where the ground-truth targets are provided as the current input at each timestep instead of the prior model prediction. How does BPTT differ when training a model with teacher forcing versus with its own predictions (i.e., autoregressive training), and what specific autograd considerations, if any, should be made in the latter approach?
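For clarity, here is a rough sketch of what I mean in questions 1 and 2 (made-up sizes, with an nn.GRUCell standing in for the decoder):

# (1) accumulate the loss over all timesteps and call backward() once -> full BPTT
# (2) update at every timestep but detach the hidden state -> one-step (truncated) BPTT
import torch
import torch.nn as nn

cell = nn.GRUCell(4, 4)
optimizer = torch.optim.SGD(cell.parameters(), lr=0.1)
inputs = torch.randn(6, 1, 4)   # 6 timesteps, batch of 1
target = torch.randn(6, 1, 4)

# (1) full BPTT: one graph spanning the whole sequence
hidden = torch.zeros(1, 4)
loss = 0.0
for t in range(6):
    hidden = cell(inputs[t], hidden)
    loss = loss + (hidden - target[t]).pow(2).mean()
loss.backward()                  # gradients flow back through every timestep
optimizer.step()
optimizer.zero_grad()

# (2) per-timestep updates: detach the hidden state, so each step has its own graph
hidden = torch.zeros(1, 4)
for t in range(6):
    hidden = cell(inputs[t], hidden.detach())
    loss = (hidden - target[t]).pow(2).mean()
    loss.backward()              # only reaches the parameters through this single step
    optimizer.step()
    optimizer.zero_grad()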