Gradient computation with index_select (for recursive neural networks)

Hi all,

I’m trying to implement a simple recursive neural network.

My cell looks something like this:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable  # used by the snippets below

class ReNNCell(nn.Module):
    def __init__(self, dim):
        super(ReNNCell, self).__init__()
        self.dim = dim
        self.W = nn.Linear(dim*2, dim)    # combines two concatenated child vectors
        self.W_score = nn.Linear(dim, 1)  # scores the combined representation

    def forward(self, inputs):
        assert inputs.size()[0] == 1, 'we expect batch size = 1'
        rep = F.relu(self.W(inputs))
        score = F.relu(self.W_score(rep))
        return rep, score

It will take two concatenated inputs and then return a new representation of these inputs and a corresponding score.
Based on this score I would like to build up a tree as follows:

tensors = [Variable(torch.randn(1,3)), Variable(torch.randn(1,3)), Variable(torch.randn(1,3)), Variable(torch.randn(1,3))]

cats = []
for i in range(1,4,2):
    cats.append(torch.cat([tensors[i-1], tensors[i]],1))

cell = ReNNCell(3)
outputs = [cell(c) for c in cats]
reps = [o[0] for o in outputs]
scores = [o[1] for o in outputs]
scores = torch.cat(scores, 1)
reps = torch.cat(reps, 0)
max_score, max_index = torch.max(scores, 1)
max_index = torch.squeeze(max_index, 1)
max_score = torch.squeeze(max_score, 1)

Based on this selection I will compute some dummy loss.

crit = torch.nn.MSELoss()
loss = crit(reps.index_select(0,max_index), Variable(torch.ones(1,3)))

If I call
loss.backward()
I get the following error:
…/pytorch/torch/autograd/variable.py in backward(self, gradient, retain_variables)
    156 'or with gradient w.r.t. the variable')
    157 gradient = self.data.new().resize_as_(self.data).fill_(1)
--> 158 self._execution_engine.run_backward((self,), (gradient,), retain_variables)
    159
    160 def register_hook(self, hook):

RuntimeError: could not compute gradients for some functions (Threshold, Threshold)

Could anybody point me in the right direction? Is this a bug or do I perhaps have to somehow compute the gradients myself?

Hi,

It may be related to this issue.
Does replacing
max_index = torch.squeeze(max_index, 1)
with
max_index = torch.squeeze(Variable(max_index.data), 1)
solve your issue?

Yes it does!

So index_select somehow discards the Variable state (or is it squeeze)?

No, it's a problem on the PyTorch side (not in your code), in the way the backward engine computes dependencies.
It's triggered by the fact that you use the max_index returned by the max operation, which is not differentiable.
You can use the workaround until this issue is solved.
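
For readers on a recent PyTorch release, here is a minimal sketch of the same workaround; it is not the original poster's exact code. It assumes Variable has been merged into Tensor, so `max_index.detach()` stands in for `Variable(max_index.data)`, and that `torch.max(scores, 1)` returns 1-D results, so the `squeeze` calls are dropped. The fixed seed and the `retain_grad()` call are additions for illustration only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # fixed seed, only so the sketch is repeatable


class ReNNCell(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.W = nn.Linear(dim * 2, dim)
        self.W_score = nn.Linear(dim, 1)

    def forward(self, inputs):
        rep = F.relu(self.W(inputs))
        score = F.relu(self.W_score(rep))
        return rep, score


cell = ReNNCell(3)
tensors = [torch.randn(1, 3) for _ in range(4)]
cats = [torch.cat([tensors[i - 1], tensors[i]], 1) for i in range(1, 4, 2)]

outputs = [cell(c) for c in cats]
reps = torch.cat([o[0] for o in outputs], 0)    # shape (2, 3)
scores = torch.cat([o[1] for o in outputs], 1)  # shape (1, 2)
reps.retain_grad()  # keep the gradient of this non-leaf tensor for inspection

max_score, max_index = torch.max(scores, 1)  # both have shape (1,)
# The workaround: pass the index as a plain tensor with no autograd history.
max_index = max_index.detach()

selected = reps.index_select(0, max_index)  # shape (1, 3)
loss = nn.MSELoss()(selected, torch.ones(1, 3))
loss.backward()

other_row = 1 - max_index.item()
print(reps.grad[other_row])      # all zeros: only the selected row receives gradient
print(cell.W_score.weight.grad)  # None: the integer index carries no gradient to the scores
```

Note that the selection itself is not trained this way: the gradient reaches W through the selected representation, but W_score gets no gradient from this loss because the index is an integer tensor.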

Thank you so much :)