Consider three tensors a = 5*torch.rand(5), b = torch.zeros(5,5), and c = torch.zeros(5). I just realized that b[a.long(), a.long()] = c disconnects a from the computation graph due to the casting operation (a.long()). In my case, a has history and a.requires_grad=True. How can I avoid it getting disconnected during this operation?
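For reference, a minimal sketch of the behavior (same tensor names as above); the cast produces a fresh tensor with no gradient history:

import torch

a = 5 * torch.rand(5, requires_grad=True)   # a has history and requires grad
b = torch.zeros(5, 5)
c = torch.zeros(5)

idx = a.long()                  # the cast returns a new, detached tensor
print(idx.requires_grad)        # False
print(idx.grad_fn)              # None: the graph is cut at the cast
b[idx, idx] = c                 # the assignment works, but no gradient can flow back to a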
Use torch.cat or torch.vstack.
I’m not sure what you mean. I think the issue is the actual casting, i.e., 3.426 -> 3, which is not differentiable.
Here is the problem broken down into a minimal working example:
import torch
import torch.nn as nn

s = 100
data = s * torch.ones(s)
data.requires_grad = True
target = torch.rand(s)
target.requires_grad = True

net = nn.Linear(s, 10)
optimizer = torch.optim.Adam(net.parameters(), 1e-5)

for i in range(s):
    optimizer.zero_grad()
    idx = net(data).clamp(0, s - 1).long()   # .long() detaches idx from the graph
    loss = torch.index_select(target, 0, idx).mean()
    loss.backward()
    optimizer.step()
    print(loss)
The issue is that the linear layer doesn't receive gradients, and the reason is that idx.requires_grad=False due to the long conversion, which index_select requires.
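This is easy to verify in isolation; the dtype cast returns a detached tensor even when its input is part of the graph (a minimal check, independent of the example above):

import torch

x = torch.rand(3, requires_grad=True)
y = x.clamp(0, 2)
print(y.requires_grad)           # True: clamp keeps the graph
print(y.long().requires_grad)    # False: the cast to long detaches
print(y.long().grad_fn)          # None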
You are right, I misunderstood the question. The problem is that I see no way to avoid it: the integer cast itself is not differentiable. What you can do is use a differentiable step-like function to push the prediction very close to its integer (long) version and apply the loss with the smallest possible error.
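For what it's worth, one way to read that suggestion is to replace the hard integer index with soft one-hot weights over all valid positions, so the selection stays differentiable. This is only a sketch under that interpretation; tau is an assumed sharpness parameter (smaller values approximate hard indexing more closely), not something from the original example:

import torch
import torch.nn as nn

s = 100
data = s * torch.ones(s)
target = torch.rand(s)

net = nn.Linear(s, 10)
optimizer = torch.optim.Adam(net.parameters(), 1e-5)

tau = 0.1                                            # assumed sharpness of the soft indexing
positions = torch.arange(s, dtype=torch.float32)     # all valid indices 0..s-1

for i in range(s):
    optimizer.zero_grad()
    idx_float = net(data).clamp(0, s - 1)            # continuous "index", keeps gradients
    # softmax over the negative scaled squared distance to every integer position
    weights = torch.softmax(-(idx_float.unsqueeze(1) - positions) ** 2 / tau, dim=1)
    # differentiable surrogate for index_select: weighted sum over target
    loss = (weights @ target).mean()
    loss.backward()
    optimizer.step()
    print(loss)

With this surrogate the linear layer does receive gradients, at the cost of the selection being an approximation of the hard index_select.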