Indexing a variable with a variable

Your TensorFlow link points to a private repository, so I can’t view it.

The indexing operation doesn’t have a gradient defined w.r.t. the index argument. That’s not a limitation of define-by-run; it’s a property of the operation itself: the index has an integer domain.
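You can check this directly: autograd refuses to track integer tensors at all. A quick demonstration (using the newer tensor API rather than Variable):

import torch

# Index tensors are integer-typed, and integer tensors can't require grad:
try:
    torch.tensor([0, 1], requires_grad=True)  # int64 tensor
except RuntimeError as e:
    print(e)  # only floating point tensors can require gradients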

You need a differentiable sampling operation for spatial transformer networks. You can implement an STN in PyTorch in roughly the same way as in TensorFlow: sample at floor(idx) and at floor(idx) + 1, then linearly interpolate between the two. Note that the sampling itself doesn’t produce a meaningful gradient; it’s the interpolation weights that produce a useful gradient w.r.t. idx:

import torch
from torch.autograd import Variable

torch.manual_seed(0)
x = Variable(torch.randn(3, 3), requires_grad=True)
# Keep the index as a float so autograd can track it
idx = Variable(torch.FloatTensor([0, 1]), requires_grad=True)

# floor() has zero gradient everywhere, so detaching it loses nothing
i0 = idx.floor().detach()
i1 = i0 + 1

# Sample the rows on either side of idx; no gradient flows through
# the integer indices themselves
y0 = x.index_select(0, i0.long())
y1 = x.index_select(0, i1.long())

# Interpolation weights: these are differentiable functions of idx
Wa = (i1 - idx).unsqueeze(1).expand_as(y0)
Wb = (idx - i0).unsqueeze(1).expand_as(y1)

# Linear interpolation; d(out)/d(idx) works out to y1 - y0, the
# difference between the two neighbouring rows
out = Wa * y0 + Wb * y1

print(out)
out.sum().backward()
print(idx.grad)

(You’ll probably want to use gather instead of index_select, and you’ll need to interpolate in two dimensions instead of just one; a rough sketch of that follows.)
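Here’s a minimal sketch of 2-D bilinear sampling with gather. The function name bilinear_sample and the flattened-index layout are my own choices for illustration, not anything from torch, and it uses the newer tensor API; note also that current PyTorch provides torch.nn.functional.grid_sample, which implements this kind of sampling directly:

import torch

def bilinear_sample(img, px, py):
    # img: (H, W) image; px, py: float coordinate tensors of shape (N,)
    H, W = img.shape
    x0 = px.floor().detach()
    y0 = py.floor().detach()
    x1 = x0 + 1
    y1 = y0 + 1

    # Clamp so the +1 neighbours stay inside the image
    x0c = x0.clamp(0, W - 1).long()
    x1c = x1.clamp(0, W - 1).long()
    y0c = y0.clamp(0, H - 1).long()
    y1c = y1.clamp(0, H - 1).long()

    # Gather the four neighbouring pixels from the flattened image
    flat = img.view(-1)
    Ia = flat.gather(0, y0c * W + x0c)  # top-left
    Ib = flat.gather(0, y0c * W + x1c)  # top-right
    Ic = flat.gather(0, y1c * W + x0c)  # bottom-left
    Id = flat.gather(0, y1c * W + x1c)  # bottom-right

    # The interpolation weights carry the gradient w.r.t. px and py
    wa = (x1 - px) * (y1 - py)
    wb = (px - x0) * (y1 - py)
    wc = (x1 - px) * (py - y0)
    wd = (px - x0) * (py - y0)

    return wa * Ia + wb * Ib + wc * Ic + wd * Id

img = torch.randn(4, 4, requires_grad=True)
px = torch.tensor([1.3, 2.7], requires_grad=True)
py = torch.tensor([0.5, 1.1], requires_grad=True)

bilinear_sample(img, px, py).sum().backward()
print(px.grad, py.grad)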
