# Gradient computation with index_select (for recursive neural networks)

Hi all,

I’m trying to implement a simple recursive neural network.

My cell looks something like this:

``````
import torch
import torch.nn as nn
import torch.nn.functional as F

class ReNNCell(nn.Module):
    def __init__(self, dim):
        super(ReNNCell, self).__init__()
        self.dim = dim
        # Maps two concatenated child representations to one parent representation.
        self.W = nn.Linear(dim*2, dim)
        # Maps a parent representation to a scalar score.
        self.W_score = nn.Linear(dim, 1)

    def forward(self, inputs):
        # size(0) is the batch dimension
        assert inputs.size(0) == 1, 'we expect batch size = 1'
        rep = F.relu(self.W(inputs))
        score = F.relu(self.W_score(rep))
        return rep, score
``````

It will take two concatenated inputs and then return a new representation of these inputs and a corresponding score.
Based on this score I would like to build up a tree as follows:

``````
from torch.autograd import Variable

tensors = [Variable(torch.randn(1,3)), Variable(torch.randn(1,3)), Variable(torch.randn(1,3)), Variable(torch.randn(1,3))]

# Concatenate neighbouring pairs: (0, 1) and (2, 3)
cats = []
for i in range(1,4,2):
    cats.append(torch.cat([tensors[i-1], tensors[i]],1))

cell = ReNNCell(3)
outputs = [cell(c) for c in cats]
reps = [o[0] for o in outputs]    # first element of each output: representation
scores = [o[1] for o in outputs]  # second element of each output: score
scores = torch.cat(scores, 1)
reps = torch.cat(reps, 0)
max_score, max_index = torch.max(scores, 1)
max_index = torch.squeeze(max_index, 1)
max_score = torch.squeeze(max_score, 1)
``````

Based on this selection I will compute some dummy loss.

``````
crit = torch.nn.MSELoss()
loss = crit(reps.index_select(0,max_index), Variable(torch.ones(1,3)))
``````

If I call `loss.backward()`, I get the following error (traceback excerpt):

``````
    156     'or with gradient w.r.t. the variable')
    159
    160 def register_hook(self, hook):

RuntimeError: could not compute gradients for some functions (Threshold, Threshold)
``````

Could anybody point me in the right direction? Is this a bug or do I perhaps have to somehow compute the gradients myself?

Hi,

It may be related to this issue.
Does replacing
`max_index = torch.squeeze(max_index, 1)`
by
`max_index = torch.squeeze(Variable(max_index.data), 1)`
fix the problem?
It's triggered by the fact that you use the `max_index` returned by the max operation, which is not differentiable.
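For anyone reading this on a recent PyTorch release (0.4 or later, where `Variable` was merged into `Tensor`), here is a minimal sketch of the same pipeline. It assumes the cell from the question with `dim = 3`. `max_index.detach()` is used as the rough counterpart of `Variable(max_index.data)`, and the `retain_grad()` call is only there to inspect the gradient of the intermediate `reps` tensor. It is an illustration of how the gradient flows, not the one correct way to write the model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class ReNNCell(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.W = nn.Linear(dim * 2, dim)
        self.W_score = nn.Linear(dim, 1)

    def forward(self, inputs):
        rep = F.relu(self.W(inputs))
        score = F.relu(self.W_score(rep))
        return rep, score

cell = ReNNCell(3)
tensors = [torch.randn(1, 3) for _ in range(4)]
cats = [torch.cat([tensors[i - 1], tensors[i]], 1) for i in range(1, 4, 2)]

outputs = [cell(c) for c in cats]
reps = torch.cat([rep for rep, _ in outputs], 0)        # shape (2, 3)
scores = torch.cat([score for _, score in outputs], 1)  # shape (1, 2)

# The index returned by max is an integer tensor and carries no gradient.
max_score, max_index = torch.max(scores, 1)  # max_index has shape (1,)
max_index = max_index.detach()               # explicit cut from the graph

reps.retain_grad()  # keep the gradient of this non-leaf tensor for inspection
selected = reps.index_select(0, max_index)
loss = F.mse_loss(selected, torch.ones(1, 3))
loss.backward()

other = 1 - max_index.item()  # the row that was not selected
print(reps.grad)              # the row at `other` is all zeros
print(cell.W_score.weight.grad)  # None: the loss never reaches the scorer
```

Two things to note. The gradient reaches only the selected row of `reps`. And because the score influences the loss only through the non-differentiable index, `W_score` receives no gradient from this loss at all. If the scoring layer should be trained, the loss has to depend on `max_score` (or on `scores`) directly.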
Thank you so much 