Replace some values of output distribution in a differentiable way

Hi all.

Is there a way to replace some values of a network's output in a differentiable way?

For example,

import torch
import torch.nn.functional as F

idx = [[1, 2], [0, 1]]
idx = torch.tensor(idx)
prob = [[0.4, 0.6], [0.55, 0.45]]
prob = torch.tensor(prob)
# dist = f.forward(input)
dist = [[2, 7, 2, 4.5], [8, 0.2, 2, 4.5]]
dist = torch.tensor(dist)

print('=== dist before')
print(dist)
print(F.softmax(dist, dim=-1), '\n')

for i in range(dist.size(0)):
    dist[i, idx[i]] = prob[i] * dist[i, idx[i]].sum()

print('=== dist after')
print(dist)
print(F.softmax(dist, dim=-1))

The output is,

=== dist before
tensor([[2.0000, 7.0000, 2.0000, 4.5000],
        [8.0000, 0.2000, 2.0000, 4.5000]])
tensor([[6.1502e-03, 9.1277e-01, 6.1502e-03, 7.4925e-02],
        [9.6797e-01, 3.9661e-04, 2.3994e-03, 2.9230e-02]]) 

=== dist after
tensor([[2.0000, 3.6000, 5.4000, 4.5000],
        [4.5100, 3.6900, 2.0000, 4.5000]])
tensor([[0.0208, 0.1030, 0.6230, 0.2533],
        [0.3981, 0.1753, 0.0324, 0.3942]])

As you can see, I am basically redistributing the values at indices [1, 2] of [2, 7, 2, 4.5] according to the [0.4, 0.6] ratio. So the calculation is 0.4 * (7 + 2) = 3.6 and 0.6 * (7 + 2) = 5.4. When I do this, the grad_fn changes from AddmmBackward to CopySlices. Does this mean the tensor is out of the computation graph? If so, how do I handle this problem?
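For reference, the same replacement can be done out-of-place with `gather`/`scatter`, which avoids the in-place slice assignment entirely. This is only a sketch with the toy values above, not code from the original post:

```python
import torch

idx = torch.tensor([[1, 2], [0, 1]])
prob = torch.tensor([[0.4, 0.6], [0.55, 0.45]])
dist = torch.tensor([[2.0, 7.0, 2.0, 4.5],
                     [8.0, 0.2, 2.0, 4.5]], requires_grad=True)

# Gather the selected entries, redistribute their row-wise sum by `prob`,
# then scatter the new values back. `Tensor.scatter` (without the trailing
# underscore) is the out-of-place variant, so `out` is a fresh tensor.
selected = dist.gather(1, idx)                       # [[7, 2], [8, 0.2]]
new_vals = prob * selected.sum(dim=1, keepdim=True)  # [[3.6, 5.4], [4.51, 3.69]]
out = dist.scatter(1, idx, new_vals)

print(out)          # [[2.0, 3.6, 5.4, 4.5], [4.51, 3.69, 2.0, 4.5]]
print(out.grad_fn)  # still part of the graph
out.sum().backward()
print(dist.grad)    # gradients reach `dist`
```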

I mean, you can do that. What you should think about is whether those gradients make sense for your use case. Note that even when you replace the values in place, gradients will still flow through those tensors.
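A minimal check of this claim, using the values from the question (an assumption: the tensor is cloned first, since an in-place write on a leaf tensor that requires grad would raise an error):

```python
import torch

dist = torch.tensor([[2.0, 7.0, 2.0, 4.5]], requires_grad=True)
idx = torch.tensor([[1, 2]])
prob = torch.tensor([[0.4, 0.6]])

# Clone before the in-place write; in a real network the output of
# forward() is already a non-leaf tensor, so the clone is not needed there.
out = dist.clone()
for i in range(out.size(0)):
    out[i, idx[i]] = prob[i] * out[i, idx[i]].sum()

out.sum().backward()
print(dist.grad)  # non-None: CopySlices is differentiable
```

Here the ratios sum to 1, so the row sum is preserved and every entry of `dist.grad` comes out as 1.0, including the replaced positions.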