Hi all.
Is there a way to replace some values of a network's output in a differentiable way?
For example,
import torch
import torch.nn.functional as F

idx = [[1, 2], [0, 1]]
idx = torch.tensor(idx)
prob = [[0.4, 0.6], [0.55, 0.45]]
prob = torch.tensor(prob)
# dist = f.forward(input)
dist = [[2, 7, 2, 4.5], [8, 0.2, 2, 4.5]]
dist = torch.tensor(dist)
print('=== dist before')
print(dist)
print(F.softmax(dist, dim=-1), '\n')
for i in range(dist.size(0)):
    dist[i, idx[i]] = prob[i] * dist[i, idx[i]].sum()
print('=== dist after')
print(dist)
print(F.softmax(dist, dim=-1))
The output is,
=== dist before
tensor([[2.0000, 7.0000, 2.0000, 4.5000],
[8.0000, 0.2000, 2.0000, 4.5000]])
tensor([[6.1502e-03, 9.1277e-01, 6.1502e-03, 7.4925e-02],
[9.6797e-01, 3.9661e-04, 2.3994e-03, 2.9230e-02]])
=== dist after
tensor([[2.0000, 3.6000, 5.4000, 4.5000],
[4.5100, 3.6900, 2.0000, 4.5000]])
tensor([[0.0208, 0.1030, 0.6230, 0.2533],
[0.3981, 0.1753, 0.0324, 0.3942]])
As you can see, I am basically replacing the values at indices [1, 2] of [2, 7, 2, 4.5] according to the ratio [0.4, 0.6]. So the calculation is 0.4 * (7 + 2) = 3.6 and 0.6 * (7 + 2) = 5.4.
If I do this, the grad_fn of dist changes from AddmmBackward to CopySlices. Does this mean the replaced values are out of the computation graph? If so, how do I handle this problem?
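For what it's worth, here is a sketch of an out-of-place alternative I have been experimenting with, using gather/scatter instead of in-place slice assignment (variable names are my own). The idea is that dist.scatter (the non-in-place version of scatter_) builds a new tensor, so nothing is overwritten:

```python
import torch
import torch.nn.functional as F

idx = torch.tensor([[1, 2], [0, 1]])
prob = torch.tensor([[0.4, 0.6], [0.55, 0.45]])
dist = torch.tensor([[2.0, 7.0, 2.0, 4.5],
                     [8.0, 0.2, 2.0, 4.5]], requires_grad=True)

# Gather the values at the target indices and redistribute their
# row-wise sum according to `prob`, without writing into `dist`.
gathered_sum = dist.gather(1, idx).sum(dim=1, keepdim=True)  # shape (2, 1)
new_vals = prob * gathered_sum                               # shape (2, 2)

# scatter (no trailing underscore) returns a new tensor;
# `dist` itself is left untouched.
new_dist = dist.scatter(1, idx, new_vals)
print(new_dist)
print(new_dist.grad_fn)  # a scatter-type backward node, not CopySlices
```

This produces the same values as the loop above ([2, 3.6, 5.4, 4.5] in the first row), but I am not sure whether it is actually necessary, or whether CopySlices already keeps the gradient intact.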