Indexing a variable with a variable

How do you index a variable with another variable? For instance, it’s not clear how you could implement a spatial transformer network (STN), since the output of the transformer layer would be a Variable. Does using idx.data work, or will that disconnect the graph?

Example

import torch
from torch.autograd import Variable
x = Variable(torch.randn(3,3))
idx = Variable(torch.LongTensor([0,1]), requires_grad=True)

# doesn't work
print(x[idx])

# works, but then you can't call .backward()?
print(x[idx.data])

t = torch.sum(x[idx.data])
t.backward() # gives an error about no graph nodes require gradients

You can use index_select:

import torch
from torch.autograd import Variable
x = Variable(torch.randn(3,3), requires_grad=True)
idx = Variable(torch.LongTensor([0,1]))
print(x.index_select(0, idx))

Note that the index variable (idx) can’t have requires_grad set to True. The variable being indexed (x) can have requires_grad=True.

http://pytorch.org/docs/tensors.html#torch.Tensor.index_select
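A minimal sketch of the same call in current PyTorch, assuming version 0.4 or later (where Variable was merged into Tensor, so plain tensors carry requires_grad). It shows that the gradient lands on the rows of x that were selected, while the int64 index receives none:

```python
import torch

# Plain tensors replace Variable; only the indexed tensor needs requires_grad.
x = torch.randn(3, 3, requires_grad=True)
idx = torch.tensor([0, 1])  # int64 index, carries no gradient

rows = x.index_select(0, idx)  # shape (2, 3)
rows.sum().backward()

# Rows 0 and 1 were selected, so each of their elements receives gradient 1;
# row 2 was not selected and receives 0.
print(x.grad)
```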


OK, thank you for that explanation. But then it seems that an STN isn’t possible under that condition. I know this is very broad, but is there any way to index or grab values from a Variable tensor with another index-like Variable that does have requires_grad=True, other than using detach()? torch.gather doesn’t work either. Is scatter an option? It seems straightforward in TF to so-called “differentiate the index”. Is this just a limitation of define-by-run?

Here’s a fairly straightforward STN gist I made showing how it should work, except for the error in the indexing. And here is a TensorFlow version of the transformer layer.

No worries, though: if anyone else stumbles upon this and has insight, let me know.

EDIT: Eh, I guess I see in the TF example above and in the deformable conv PyTorch example how they still propagate the gradient through means other than the index. It’s still unclear to me whether differentiating the index is an inherent limitation of PyTorch or of the approach in general.


Your TensorFlow link points to a private repository, so I can’t view it.

The index operation doesn’t have any gradients defined w.r.t. the index variable. That’s not a limitation of “define by run”, that’s a property of the operation: it has integer domain.
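In current PyTorch this property is enforced up front. As a small sketch, assuming a recent PyTorch version: asking an integer tensor to require gradients raises a RuntimeError, while a floating-point tensor of positions is accepted (the exact error message may vary between versions):

```python
import torch

# An integer index has no infinitesimal neighbourhood, so autograd refuses it.
try:
    idx = torch.tensor([0, 1], requires_grad=True)  # int64 dtype
    raised = False
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)

# A floating-point position tensor can require gradients.
pos = torch.tensor([0.0, 1.0], requires_grad=True)
print(pos.requires_grad)
```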

You need a differentiable sampling operation for spatial transformer networks. You can implement an STN in PyTorch in roughly the same way as in TensorFlow: sample at floor(idx) and floor(idx) + 1 and linearly interpolate between the two. Note that the sampling itself doesn’t produce a meaningful gradient; it’s the interpolation that produces a useful gradient:

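import torch
from torch.autograd import Variable

torch.manual_seed(0)
x = Variable(torch.randn(3,3), requires_grad=True)
# Fractional row positions: floats, so they can require grad
idx = Variable(torch.FloatTensor([0,1]), requires_grad=True)

# Integer neighbours of each position (i1 must stay within x's rows);
# detached because floor has no useful gradient
i0 = idx.floor().detach()
i1 = i0 + 1

# Fetch the two neighbouring rows (integer indexing, no gradient w.r.t. the index)
y0 = x.index_select(0, i0.long())
y1 = x.index_select(0, i1.long())

# Linear interpolation weights depend on idx, so the gradient flows through them
Wa = (i1 - idx).unsqueeze(1).expand_as(y0)
Wb = (idx - i0).unsqueeze(1).expand_as(y1)

out = Wa * y0 + Wb * y1

print(out)
out.sum().backward()
print(idx.grad)  # d(out.sum())/d(idx) = row sums of (y1 - y0)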

(For an STN you will probably want gather instead of index_select, and you will need to interpolate in two dimensions instead of just one.)
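A minimal sketch of that two-dimensional version, assuming a single-channel image of shape (H, W), float sample coordinates px (column) and py (row), and clamping of corner indices at the border (an assumption; an STN might pad with zeros instead). The function name bilinear_sample is hypothetical. As in the 1-D case, the four corners are fetched with gather on integer indices and carry no gradient; the bilinear weights are functions of px and py, so that is where the gradient comes from:

```python
import torch

def bilinear_sample(img, px, py):
    """Hypothetical helper: sample img (H, W) at float coords px, py of shape (N,)."""
    H, W = img.shape
    # Integer corners; detached because floor has no useful gradient.
    x0 = px.floor().detach()
    y0 = py.floor().detach()
    x1 = x0 + 1
    y1 = y0 + 1

    # Clamp corner indices into the image (assumed border handling).
    x0i = x0.clamp(0, W - 1).long()
    x1i = x1.clamp(0, W - 1).long()
    y0i = y0.clamp(0, H - 1).long()
    y1i = y1.clamp(0, H - 1).long()

    flat = img.reshape(-1)

    def fetch(yi, xi):
        # gather on the flattened image: linear index = row * W + col
        return flat.gather(0, yi * W + xi)

    Ia = fetch(y0i, x0i)  # top-left
    Ib = fetch(y1i, x0i)  # bottom-left
    Ic = fetch(y0i, x1i)  # top-right
    Id = fetch(y1i, x1i)  # bottom-right

    # Bilinear weights use the unclamped corners so they always sum to 1.
    wa = (x1 - px) * (y1 - py)
    wb = (x1 - px) * (py - y0)
    wc = (px - x0) * (y1 - py)
    wd = (px - x0) * (py - y0)
    return wa * Ia + wb * Ib + wc * Ic + wd * Id


# Usage: img[r, c] = 4 * r + c, so interior samples equal 4 * py + px exactly.
img = torch.arange(12.0).reshape(3, 4)
px = torch.tensor([1.5, 0.25], requires_grad=True)
py = torch.tensor([0.5, 1.0], requires_grad=True)

out = bilinear_sample(img, px, py)
out.sum().backward()
print(out)
print(px.grad, py.grad)  # slopes of the image along columns and rows
```

Because the test image is linear in row and column, the interpolated values and the gradients with respect to px and py match that linear function, which makes the sketch easy to check by hand.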


OK, that explanation really makes it click. I got it to work with interpolation; the TF repository I linked was just using bilinear interpolation. Thanks a billion for taking the time.
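Later PyTorch releases also ship the bilinear sampling used by STNs as built-ins, torch.nn.functional.affine_grid and torch.nn.functional.grid_sample. A minimal sketch, assuming PyTorch 1.3 or later (for the align_corners argument); the theta values below are placeholders standing in for the output of a localization network:

```python
import torch
import torch.nn.functional as F

# One 1-channel 4x4 image, layout (N, C, H, W).
img = torch.arange(16.0).reshape(1, 1, 4, 4)

# Placeholder affine parameters (N, 2, 3): identity plus a small x-shift.
theta = torch.tensor([[[1.0, 0.0, 0.1],
                       [0.0, 1.0, 0.0]]], requires_grad=True)

# Sampling grid of normalized coordinates, shape (N, H, W, 2).
grid = F.affine_grid(theta, size=img.shape, align_corners=False)
# Bilinear sampling; differentiable w.r.t. both img and grid.
out = F.grid_sample(img, grid, mode="bilinear", align_corners=False)

out.sum().backward()
print(out.shape)         # torch.Size([1, 1, 4, 4])
print(theta.grad.shape)  # torch.Size([1, 2, 3])
```

As in the hand-written version, the integer corner lookups inside grid_sample carry no gradient; the gradient reaches theta through the interpolation weights.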

Is the need for explicit index_select a “feature” or a “bug”? Should the indexing operator call index_select automatically for the Variables?

A “bug” (or at least a missing feature).
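As a sketch of how this looks in a recent PyTorch version (an assumption about the reader's version): plain indexing x[idx] with an int64 index tensor is differentiable with respect to x and gives the same rows as index_select. Repeated indices are allowed, and their gradients accumulate:

```python
import torch

x = torch.randn(3, 3, requires_grad=True)
idx = torch.tensor([0, 2, 0])  # row 0 is selected twice

y = x[idx]  # advanced indexing, same rows as x.index_select(0, idx)
assert torch.equal(y, x.index_select(0, idx))
y.sum().backward()

# Row 0 gets gradient 2 (selected twice), row 2 gets 1, row 1 gets 0.
print(x.grad)
```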

Hello everyone,

I have a problem with indexing.

In section 3.3 of this paper:

We first select Y frames (i.e. keyframes) based on the prediction scores from the decoder.

The decoder output has shape [2, 320]: a non-keyframe score and a keyframe score for each of the 320 frames. We want to derive a 0/1 vector from the decoder output, but the mapping from [2, 320] to a 0/1 vector doesn’t seem to be differentiable.

How can this be implemented in PyTorch?

Thank you very much.
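The quoted passage does not say how the selection is made differentiable. One commonly used workaround, offered here as an assumption rather than the paper's method, is a straight-through estimator: the forward pass uses the hard 0/1 mask, while the backward pass routes gradients through the soft keyframe probability. The scores are random stand-ins for the decoder output, and the feats tensor and its width of 64 are hypothetical placeholders:

```python
import torch

torch.manual_seed(0)
# Stand-in decoder output: row 0 = non-keyframe score, row 1 = keyframe score.
scores = torch.randn(2, 320, requires_grad=True)

probs = scores.softmax(dim=0)              # per-frame probabilities, shape (2, 320)
hard = (probs[1] > probs[0]).float()       # 0/1 keyframe mask, no gradient on its own
# Straight-through: forward value is (numerically) `hard`,
# backward uses the gradient of probs[1].
mask = hard + probs[1] - probs[1].detach()

# Hypothetical per-frame features to select from, shape (320, 64).
feats = torch.randn(320, 64)
selected = mask.unsqueeze(1) * feats       # non-keyframe rows are zeroed out
selected.sum().backward()

print(int(hard.sum()), "frames selected")
print(scores.grad.shape)
```

The gradient this produces is a biased surrogate, not the true derivative of the hard selection. torch.nn.functional.gumbel_softmax with hard=True is a built-in variant that adds Gumbel noise to the scores before applying the same straight-through step.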