Backprop through indexing operation

Consider three tensors a = 5*torch.rand(5), b = torch.zeros(5, 5), and c = torch.zeros(5). I just realized that b[a.long(), a.long()] = c disconnects a from the computation graph because of the casting operation (a.long()). In my case, a has history and a.requires_grad=True. How can I avoid a getting disconnected by this operation?
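
For reference, here is a minimal sketch (same names as above) showing that it is the cast itself that drops the graph:

    import torch

    a = 5 * torch.rand(5, requires_grad=True)  # a has history (grad_fn from the multiplication)
    b = torch.zeros(5, 5)
    c = torch.zeros(5)

    idx = a.long()              # the integer cast returns a new tensor outside the graph
    print(idx.requires_grad)    # False -> nothing can flow back to a
    b[idx, idx] = c             # the assignment only ever sees the detached idx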

Use torch.cat or torch.vstack.

I’m not sure what you mean. I think the issue is the actual casting, i.e., 3.426 -> 3, which is not differentiable.
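
Even a float-valued rounding op wouldn't help here, as far as I can tell: the mapping is piecewise constant, so its gradient is zero almost everywhere. A quick sketch using torch.floor as a stand-in for the cast:

    import torch

    x = torch.tensor(3.426, requires_grad=True)
    y = torch.floor(x)   # stays in the graph, unlike .long()
    y.backward()
    print(x.grad)        # tensor(0.) -> zero gradient almost everywhere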

Here is the problem broken down into a minimum working example:

    import torch
    import torch.nn as nn

    s = 100

    data = s * torch.ones(s)
    data.requires_grad = True

    target = torch.rand(s)
    target.requires_grad = True

    net = nn.Linear(s, 10)
    optimizer = torch.optim.Adam(net.parameters(), 1e-5)

    for i in range(s):
        optimizer.zero_grad()
        # the .long() cast detaches idx from the graph, but index_select needs integer indices
        idx = net(data).clamp(0, s - 1).long()
        loss = torch.index_select(target, 0, idx).mean()
        loss.backward()
        optimizer.step()

        print(loss)

The issue is that the linear layer doesn't receive any gradients: idx.requires_grad is False because of the .long() conversion, which index_select requires.
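
You can confirm this by checking the gradients right after the first loss.backward() in the loop above:

    print(net.weight.grad)   # None -> the graph was cut at .long(), nothing reaches the layer
    print(target.grad)       # populated -> index_select is differentiable w.r.t. target itself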

You are right, I misunderstood the question. The problem is that I can't find a way to avoid it either. You could use a differentiable step function that pushes the numbers very close to their integer (long) versions, and apply the loss function with the lowest possible error.
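
One way to make that concrete is soft indexing: replace the hard index_select with a softmax-weighted sum over all positions, where a temperature controls how sharply the weights concentrate on the rounded index. This is only a sketch of that idea (the squared-distance weighting and the temperature tau are my own choices, not anything from the example above):

    import torch
    import torch.nn as nn

    s = 100
    data = s * torch.ones(s)
    target = torch.rand(s)

    net = nn.Linear(s, 10)
    optimizer = torch.optim.Adam(net.parameters(), 1e-5)

    positions = torch.arange(s, dtype=torch.float32)  # 0, 1, ..., s-1
    tau = 0.1                                         # smaller -> closer to a hard one-hot lookup

    for i in range(s):
        optimizer.zero_grad()
        idx = net(data).clamp(0, s - 1)               # keep the predicted indices as floats
        # soft one-hot weights over all positions, peaked near the rounded indices
        w = torch.softmax(-(idx.unsqueeze(1) - positions) ** 2 / tau, dim=1)
        loss = (w @ target).mean()                    # differentiable stand-in for index_select(...).mean()
        loss.backward()
        optimizer.step()
        print(loss.item())

As tau goes to zero the weights approach a one-hot vector at the nearest integer, while the loss stays differentiable with respect to the network output, so the linear layer now receives gradients.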