Updating hidden states without in-place operations

Hello,

I am implementing a GRU RNN using GRUCells. I have N layers, so I store my hidden states in a single N x hidden_dimension Variable.

After each GRUCell runs, I want to write its output back into the corresponding row of the hidden-state Variable, but doing so makes backward() fail.

So for example:

for idxLayer in range(self.n_layers):
    currentCell = self.grucells[idxLayer]
    output = currentCell(output, hiddenStates[idxLayer, :].view(1, -1))
    hiddenStates[idxLayer, :] = output  # this in-place assignment makes backward() fail
hiddenStates.sum().backward()  # RuntimeError

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

How can I solve the problem? Thanks!
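A minimal, self-contained reproduction of the failure might look like the sketch below. The sizes, the random input and the use of nn.ModuleList are assumptions made for illustration; the loop body mirrors the code above. The assignment itself succeeds, but each GRUCell saved a view of hiddenStates for its backward pass, and the later in-place writes bump the tensor's version counter, so autograd refuses to use the stale saved values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

n_layers, hidden_size = 2, 4  # arbitrary sizes chosen for the example
# Assumed setup: one GRUCell per layer, each with input size == hidden size.
cells = nn.ModuleList(nn.GRUCell(hidden_size, hidden_size) for _ in range(n_layers))

x = torch.randn(1, hidden_size)
hiddenStates = torch.zeros(n_layers, hidden_size)  # N x hidden_dimension

output = x
error_message = None
try:
    for idxLayer in range(n_layers):
        # The cell saves this row view of hiddenStates for its backward pass.
        output = cells[idxLayer](output, hiddenStates[idxLayer, :].view(1, -1))
        # Writing into a slice modifies hiddenStates (and all its views) in place.
        hiddenStates[idxLayer, :] = output
    hiddenStates.sum().backward()
except RuntimeError as err:
    error_message = str(err)

print(error_message)
```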

But why not use nn.GRU instead of N separate nn.GRUCell modules? If I understand correctly, what you want is exactly what nn.GRU does (it stacks N GRU cells to form an N-layer GRU network).
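For the case where no per-layer modification is needed, the nn.GRU route could be sketched as follows. The sizes are arbitrary assumptions. nn.GRU (with the default batch_first=False) takes an initial hidden state of shape (num_layers, batch, hidden_size), which is the N x hidden_dimension layout from the question plus a batch axis, and it returns the final hidden state in the same layout without any in-place writes on the user's side.

```python
import torch
import torch.nn as nn

n_layers, input_size, hidden_size = 3, 5, 4  # arbitrary sizes for the sketch
# A single nn.GRU with num_layers=n_layers stacks the layers internally.
gru = nn.GRU(input_size, hidden_size, num_layers=n_layers)

seq_len, batch = 1, 1
x = torch.randn(seq_len, batch, input_size)
# Initial hidden state: (num_layers, batch, hidden_size).
h0 = torch.zeros(n_layers, batch, hidden_size)

out, h_n = gru(x, h0)   # out: last layer's output per step; h_n: final hidden per layer
h_n.sum().backward()    # nothing was overwritten in place, so backward() succeeds
```

The trade-off is that nn.GRU gives no hook between layers, which is why the stacked-GRUCell approach is needed when each layer's output must be modified first.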

Because I need to modify the outputs of each GRUCell in a certain way before passing them as inputs to the next GRUCell :slight_smile:

OK, that makes sense.

I would suggest keeping the hidden states in a list instead, with each entry initialized to the right shape (1, hidden_size), so you can replace an entry directly rather than writing into a slice of one big tensor:

import torch

hiddenStates = []
hiddenStates.append(input)  # hiddenStates[0] is the input to the first cell
# init hidden: one (1, hidden_size) zero tensor per layer
for idxLayer in range(self.n_layers):
    hiddenStates.append(torch.zeros(1, hidden_size))
# then hiddenStates is a list of n_layers + 1 tensors
for idxLayer in range(self.n_layers):
    currentCell = self.grucells[idxLayer]  # self.grucells holds n_layers cells, indexed 0..n_layers-1
    # rebinding a list entry is not an in-place tensor operation
    hiddenStates[idxLayer+1] = currentCell(hiddenStates[idxLayer], hiddenStates[idxLayer+1])
z = 0
for hidden in hiddenStates[1:]:
    z += hidden.sum()
z.backward()

That way, no tensor is ever modified in place: each step only rebinds a list entry to the new tensor returned by the cell :wink:
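Putting the list idea into a complete module might look like the sketch below. The class name StackedGRUCells, the transform method and the choice of torch.tanh as the between-layer modification are all hypothetical placeholders for whatever the original poster actually does. If an N x hidden_dimension tensor is still wanted at the end, torch.stack builds a fresh one out of place from the list.

```python
import torch
import torch.nn as nn


class StackedGRUCells(nn.Module):
    """Hypothetical N-layer GRU built from GRUCells, with a hook between layers."""

    def __init__(self, input_size, hidden_size, n_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        # Layer 0 reads the input; later layers read the (transformed) previous hidden state.
        sizes = [input_size] + [hidden_size] * (n_layers - 1)
        self.grucells = nn.ModuleList(nn.GRUCell(s, hidden_size) for s in sizes)

    def transform(self, h):
        # Placeholder for "modify the output before the next cell"; tanh is an assumption.
        return torch.tanh(h)

    def forward(self, x, hiddens=None):
        batch = x.size(0)
        if hiddens is None:
            hiddens = [x.new_zeros(batch, self.hidden_size) for _ in range(self.n_layers)]
        new_hiddens = []
        layer_input = x
        for idxLayer, cell in enumerate(self.grucells):
            h = cell(layer_input, hiddens[idxLayer])
            new_hiddens.append(h)  # collect results in a list; never write into a tensor slice
            layer_input = self.transform(h)
        # torch.stack creates a new (n_layers, batch, hidden_size) tensor out of place.
        return torch.stack(new_hiddens)


torch.manual_seed(0)
model = StackedGRUCells(input_size=3, hidden_size=4, n_layers=2)
x = torch.randn(1, 3)
h = model(x)
h.sum().backward()
```

To carry the state to the next time step, the stacked result can be split back into a list, e.g. model(x_next, list(h)), which again avoids any indexed assignment.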


Thank you! This solved the problem :smiley:

So in general I can write something like

for idx in range(100):
    var1 = network(var1)

but if I assign to only part of the variable (which makes it an in-place operation), it no longer works. Right?
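The difference between the two patterns can be checked with a small experiment like the one below. Here network is a stand-in nn.Linear(4, 4) chosen only for illustration, and the multiplication y * y is an assumed example of an operation that saves its input for the backward pass. Rebinding a name in a loop works, while writing into a slice of a tensor that autograd has saved does not.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
network = nn.Linear(4, 4)  # stand-in for any module mapping a tensor to a same-shaped tensor

# Pattern 1: rebinding the name. Each call returns a new tensor; nothing is overwritten.
var1 = torch.randn(1, 4)
for idx in range(100):
    var1 = network(var1)
var1.sum().backward()
rebind_ok = network.weight.grad is not None

# Pattern 2: partial assignment into a tensor that autograd saved for backward.
y = network(torch.randn(1, 4))
z = y * y          # mul saves y to compute its gradient later
y[:, :2] = 0       # in-place write into y bumps its version counter
slice_error = None
try:
    z.sum().backward()
except RuntimeError as err:
    slice_error = str(err)
```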

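Yes, with one refinement: reading through an index (y = x[i]) is differentiable; it is assigning through an index (x[i] = y) that modifies x in place, and backward() fails whenever autograd had saved x (or a view of it) for the gradient computation. To update part of a tensor, build a new tensor instead (for example with torch.cat or torch.stack), or keep the pieces in a list as above.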
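One way to "replace a row" without an in-place write is to concatenate the untouched slices around the new row, as in the sketch below. The helper name replace_row and the toy update tanh(prev * w) are hypothetical; the multiplication is chosen because it saves prev, a view of hidden, so writing hidden[idx] = new_row in place would hit the same error. Each call produces a new tensor and leaves the old one, and everything autograd saved from it, untouched.

```python
import torch


def replace_row(h, idx, row):
    """Hypothetical helper: return a copy of h with row idx replaced by row (no in-place write)."""
    return torch.cat([h[:idx], row, h[idx + 1:]], dim=0)


torch.manual_seed(0)
w = torch.randn(1, 3, requires_grad=True)
hidden = torch.ones(3, 3)

for idx in range(3):
    prev = hidden[idx].view(1, -1)   # reading through an index is differentiable
    new_row = torch.tanh(prev * w)   # mul saves prev (a view of hidden) for backward
    hidden = replace_row(hidden, idx, new_row)  # rebinds `hidden`; the old tensor is untouched

hidden.sum().backward()
```

Because every row starts at 1 and is read before it is replaced, each row of the final hidden equals tanh(w).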
