Minimal-ish example code:
```python
import torch
from torch import nn

N = 5
seq_len = 3
vocab_size = 7
embedding_size = 8
embedding = nn.Embedding(vocab_size, embedding_size)
h = nn.Linear(embedding_size * seq_len, vocab_size)
encoded = torch.rand(seq_len, N, embedding_size, requires_grad=True)
out_probs = torch.zeros(seq_len, N, vocab_size)
out = torch.zeros(seq_len, N, dtype=torch.long)  # decoded token ids so far
for t in range(seq_len):
    out_emb = torch.zeros(seq_len, N, embedding_size)
    if t > 0:
        # out_t = out[:t].data
        out_t = out[:t].detach()
        out_emb[:t] = embedding(out_t)
    out_emb = encoded + out_emb
    out_emb = h(out_emb.transpose(0, 1).contiguous().view(N, -1))
    out_probs[t] = out_emb
    _, decoded = out_emb.max(dim=-1)
    out[t] = decoded  # write this step's decoded ids into `out`
loss = out_probs.sum()
loss.backward()
```

Running this fails at `loss.backward()` with:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
But if I replace `.detach()` with `.clone().detach()`, there is no error.
Why? A bug in `.detach()`? A bug in my own code? A bug in
I think you are running into the case described in the PyTorch 0.4 migration guide.
`.detach()` is safer in that it catches in-place modifications that would cause autograd to give wrong results when backpropagating through the graph that you have “detached from”.
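The same mechanism can be isolated from the original question. The sketch below assumes, consistent with the error above, that `nn.Embedding` saves its index tensor for the backward pass; the sizes and the names `emb`, `idx`, `sub` and `backward_raises` are hypothetical. A detached slice shares the version counter of its base tensor, so writing into another part of the base still trips autograd's check, while a cloned slice has its own counter.

```python
import torch
from torch import nn


def backward_raises(use_clone: bool) -> bool:
    """Return True if backward fails because a saved index tensor was bumped in place."""
    emb = nn.Embedding(10, 4)
    idx = torch.zeros(3, dtype=torch.long)  # plays the role of `out`
    # detach(): same storage and same version counter as `idx`
    # clone().detach(): independent storage and its own version counter
    sub = idx[:2].clone().detach() if use_clone else idx[:2].detach()
    y = emb(sub)   # the embedding keeps `sub` around for its backward
    idx[2] = 5     # in-place write to the base tensor bumps its version
    try:
        y.sum().backward()  # autograd compares saved versions here
    except RuntimeError:
        return True
    return False


print(backward_raises(use_clone=False))  # True
print(backward_raises(use_clone=True))   # False
```

Note that the write touches `idx[2]`, which `sub` never reads; the check works per version counter, not per element, so it fires anyway.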
Consider the following more minimal example:
```python
import torch

a = torch.arange(5., requires_grad=True)
b = a**2
c = a.detach()  # error with detach(), wrong result with .data
c.zero_()
b.sum().backward()
print(a.grad)
```
- This errors, and rightfully so, because autograd detects that `a` has been changed in place, which would break the gradient calculation.
- If you comment out the `c.zero_()` or use `.clone().detach()`, you see that `a.grad` is `2*a`, just as it should be.
- If you use `.data`, the connection is fully broken and you silently get the wrong result that `a.grad` is all zeros.
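The three cases in the bullets can be run side by side. This is a small sketch wrapping the minimal example in a hypothetical helper `grad_after_zeroing`, which returns `None` when autograd raises. It relies on `pow` saving its input `a` for backward, so zeroing `a` through `.data` changes the gradient that gets computed.

```python
import torch


def grad_after_zeroing(mode: str):
    """Run the minimal example with `c` taken via 'detach', 'clone' or 'data'."""
    a = torch.arange(5., requires_grad=True)
    b = a ** 2                    # pow saves `a` for its backward
    if mode == "detach":
        c = a.detach()            # shares storage and version counter with `a`
    elif mode == "clone":
        c = a.clone().detach()    # independent copy
    else:
        c = a.data                # shares storage but NOT the version counter
    c.zero_()
    try:
        b.sum().backward()
    except RuntimeError:
        return None               # autograd noticed the in-place change
    return a.grad


print(grad_after_zeroing("detach"))  # None
print(grad_after_zeroing("clone"))   # tensor([0., 2., 4., 6., 8.])
print(grad_after_zeroing("data"))    # tensor([0., 0., 0., 0., 0.])
```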
`.data` is intended to support some old-style updates (which could use `with torch.no_grad():` instead), e.g. in optimizers. Most likely, you should not use it in your own code unless you are sure that it is exactly the right thing to do.
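For comparison, here is a hand-written plain SGD step done the `torch.no_grad()` way instead of through `.data`. This is only an illustration with made-up values for `w` and `lr`, not how `torch.optim` is implemented internally.

```python
import torch

w = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
loss = (w ** 2).sum()
loss.backward()            # w.grad is 2 * w, i.e. [2., -4., 6.]

lr = 0.1
# Old style would be: w.data -= lr * w.grad
# Instead, turn off tracking so the in-place update is not recorded by autograd.
with torch.no_grad():
    w -= lr * w.grad       # w is now approximately [0.8, -1.6, 2.4]
w.grad.zero_()             # reset the gradient before the next step

print(w.is_leaf, w.requires_grad)  # True True
```

Unlike `.data`, the update goes through `w` itself, so its version counter is bumped and any graph that still needs the old values will raise instead of silently using the new ones.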
Ah, OK, I think I see. It is because `.detach()` doesn't implicitly create a copy of the tensor, so when I later modify that tensor, it also updates the tensor on the upstream side of `.detach()`. By cloning first, this issue doesn't arise, and all is OK?
`detach()` doesn't create copies; it only prevents gradients from being computed, and shares the data. So in your case, the `detach()` in `clone().detach()` may also be redundant, except that you save computational resources by not tracking gradients for the detached copy.
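The sharing can be checked directly. The sketch below compares storage pointers and reads `_version`, which is an internal, underscore-prefixed counter that autograd uses for its in-place checks; it is handy for inspection but is not a stable public API.

```python
import torch

x = torch.zeros(3)
d = x.detach()             # no copy: same storage as x
k = x.clone().detach()     # real copy with its own storage

print(d.data_ptr() == x.data_ptr())  # True
print(k.data_ptr() == x.data_ptr())  # False

x[0] = 1.0                 # in-place write through x
print(d[0].item(), k[0].item())      # 1.0 0.0

# detach() also shares the version counter, which is why autograd complains
print(x._version == d._version)      # True
print(k._version)                    # 0: the clone was never modified
```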