.data works, .clone().detach() works; but .detach() fails. Why?

Minimal-ish example code:

import torch
from torch import nn


N = 5
seq_len = 3
vocab_size = 7
embedding_size = 8

embedding = nn.Embedding(vocab_size, embedding_size)
h = nn.Linear(embedding_size * seq_len, vocab_size)

encoded = torch.rand(seq_len, N, embedding_size, requires_grad=True)
out_probs = torch.zeros(seq_len, N, vocab_size)
out = torch.LongTensor(seq_len, N)
for t in range(seq_len):
    out_emb = torch.zeros(seq_len, N, embedding_size)
    if t > 0:
        # embed the tokens decoded in previous steps
        # out_t = out[:t].data
        out_t = out[:t].detach()
        out_emb[:t] = embedding(out_t)
    out_emb = encoded + out_emb
    out_emb = h(out_emb.transpose(0, 1).contiguous().view(N, -1))
    out_probs[t] = out_emb
    # greedy decode and store the predicted token for this step
    _, decoded = out_emb.max(dim=-1)
    out[t] = decoded

loss = out_probs.sum()
loss.backward()

gives the error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

But if I replace .detach() with .data or .clone().detach(), there is no error.

Why? Is it a bug in .detach()? A bug in my own code? A bug in .data?


I think you are running into the case described in the PyTorch 0.4 migration guide.
.detach() is safer in that it catches in-place modifications that would cause autograd to give wrong results when backpropagating through the graph that you have “detached from”.
Consider the following more minimal example:

a = torch.arange(5., requires_grad=True)
b = a**2                 # backward of ** needs the original values of a
c = a.detach()           # error with detach(), wrong result with .data
c.zero_()                # modifies a's storage in place, since c shares it
b.sum().backward()
print(a.grad)
  • This errors, and rightfully so, because autograd detects that a has been changed in place, which would corrupt the gradient calculation.
  • If you comment out the c.zero_() or use .clone().detach(), you see that a.grad is 2*a, just as it should be.
  • If you use .data, the connection is broken completely and you silently get the wrong result a.grad == 0 (see the sketch after this list).
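
For completeness, a quick sketch of the other two variants (same setup as above, assuming PyTorch >= 0.4):

# variant 1: .clone().detach() -- zeroing the copy leaves a untouched
a = torch.arange(5., requires_grad=True)
b = a**2
c = a.clone().detach()
c.zero_()
b.sum().backward()
print(a.grad)   # tensor([0., 2., 4., 6., 8.]), i.e. 2*a, as it should be

# variant 2: .data -- the change goes behind autograd's back
a = torch.arange(5., requires_grad=True)
b = a**2
c = a.data
c.zero_()
b.sum().backward()
print(a.grad)   # tensor([0., 0., 0., 0., 0.]) -- silently wrong, no error raised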

.data is intended to support some old-style updates (which could use with torch.no_grad(): instead), e.g. in optimizers. Most likely, you should not use it in your own code unless you are absolutely sure that it is the right thing to do.
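
For example, a manual SGD-style update (just a sketch, with a made-up tensor w and learning rate lr) would look like this with torch.no_grad() instead of .data:

w = torch.randn(3, requires_grad=True)
loss = (w ** 2).sum()
loss.backward()

lr = 0.1
# old style: w.data -= lr * w.grad
with torch.no_grad():
    w -= lr * w.grad    # in-place update, not recorded by autograd
w.grad.zero_()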

Best regards

Thomas


Ah, OK, I think I see. It is because .detach() doesn’t implicitly create a copy of the tensor, so when I later modify that tensor in place, it also updates the tensor on the upstream side of .detach(). By cloning first, this issue doesn’t arise, and all is OK?
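
If I check storage sharing directly (a quick sketch), that seems to confirm it:

a = torch.arange(5.)
print(a.detach().data_ptr() == a.data_ptr())           # True: same underlying storage
print(a.clone().detach().data_ptr() == a.data_ptr())   # False: clone made a copy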


Yes, detach doesn’t create copies; it only prevents gradients from being computed, while still sharing the data. So in your case the detach in clone().detach() may well be redundant, except that it saves computational resources by not tracking gradients for the detached copy.
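
A tiny sketch of the difference between clone() and clone().detach():

a = torch.ones(3, requires_grad=True)
b = a.clone()            # copy, but still attached to the graph
c = a.clone().detach()   # copy with no graph connection
b.sum().backward()
print(a.grad)            # tensor([1., 1., 1.]) -- gradients flow back through clone()
print(c.requires_grad)   # False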
