Your TensorFlow link points to a private repository, so I can't view it.
The index operation doesn't have a gradient defined w.r.t. the index variable. That's not a limitation of define-by-run; it's a property of the operation itself: the index has an integer domain.
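A minimal illustration of the point, assuming current PyTorch: gradients flow to the *values* being indexed, but an integer index tensor can't even be marked as requiring grad (index_select is used here, but the same holds for any indexing op):

```python
import torch

x = torch.randn(3, 3, requires_grad=True)
idx = torch.tensor([0, 2])   # integer index: lives outside autograd
y = x.index_select(0, idx)
y.sum().backward()
print(x.grad)                # gradient w.r.t. the values: fine

# Asking for a gradient on the index itself is rejected outright:
try:
    torch.tensor([0, 2], requires_grad=True)
except RuntimeError as e:
    print(e)                 # only floating-point tensors can require gradients
```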
You need a differentiable sampling operation for spatial transformer networks. You can implement an STN in PyTorch in much the same way as in TensorFlow: sample at floor(idx) and at floor(idx) + 1, then linearly interpolate between the two. Note that the sampling itself doesn't produce a meaningful gradient; it's the interpolation that does:
import torch

torch.manual_seed(0)

# values we sample from: gradients w.r.t. x work as usual
x = torch.randn(3, 3, requires_grad=True)
# continuous (fractional) index we want a gradient for
idx = torch.tensor([0.0, 1.0], requires_grad=True)

i0 = idx.floor().detach()  # lower integer neighbour; detached, no gradient
i1 = i0 + 1                # upper integer neighbour

y0 = x.index_select(0, i0.long())
y1 = x.index_select(0, i1.long())

# interpolation weights: these are what carry the gradient w.r.t. idx
Wa = (i1 - idx).unsqueeze(1).expand_as(y0)
Wb = (idx - i0).unsqueeze(1).expand_as(y1)

out = Wa * y0 + Wb * y1
print(out)

out.sum().backward()
print(idx.grad)
(In practice you'll probably want gather instead of index_select, and you'll need to interpolate in two dimensions rather than just one.)
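A rough sketch of what that 2-D version might look like. The function name bilinear_sample, its arguments, and the flatten-then-gather trick are my own choices for illustration, not from any particular STN implementation; the structure mirrors the 1-D example above (detached integer corners, gradient carried by the weights):

```python
import torch

torch.manual_seed(0)

def bilinear_sample(img, rows, cols):
    # img: (H, W); rows/cols: fractional coordinates, may require grad
    H, W = img.shape
    r0 = rows.floor().detach()           # integer corners: no gradient here
    c0 = cols.floor().detach()
    r1 = (r0 + 1).clamp(max=H - 1)       # clamp so corners stay in bounds
    c1 = (c0 + 1).clamp(max=W - 1)

    def pick(r, c):
        # gather needs int64 indices; index into the flattened image
        return img.view(-1).gather(0, (r * W + c).long())

    v00, v01 = pick(r0, c0), pick(r0, c1)
    v10, v11 = pick(r1, c0), pick(r1, c1)

    # interpolation weights: these carry the gradient w.r.t. rows/cols
    wr, wc = rows - r0, cols - c0
    top = v00 * (1 - wc) + v01 * wc
    bot = v10 * (1 - wc) + v11 * wc
    return top * (1 - wr) + bot * wr

img = torch.randn(4, 4)
rows = torch.tensor([0.5, 2.25], requires_grad=True)
cols = torch.tensor([1.0, 3.0], requires_grad=True)

out = bilinear_sample(img, rows, cols)
out.sum().backward()
print(rows.grad, cols.grad)
```

At integer coordinates the weights collapse to 0/1 and the sample reduces to a plain lookup, which is a handy sanity check.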