Detach recurrent connections in RNN

Is it possible to detach the recurrent connections of a torch.nn.RNN when calling forward on it?

I could do this with a torch.nn.RNNCell, like this:

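import torch
import torch.nn as nn
l, b, x, h = 10, 4, 8, 16  # example sizes: sequence length, batch, input, hidden
rnn = nn.RNNCell(x, h)
input = torch.randn(l, b, x)
hx = torch.randn(b, h)
output = list()
for i in range(l):
    # Detach the incoming state, not the output: this step still gets gradients,
    # but they stop here instead of flowing back through earlier time steps.
    hx = rnn(input[i], hx.detach())
    output.append(hx)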

Is there a way to do this without having to explicitly iterate over the sequence? I anticipate that doing it with an RNNCell will make the forward pass much slower.

I could also do it with torch.nn.RNN by feeding it one time step at a time, but I would still have to iterate over the sequence.
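For reference, the step-by-step nn.RNN variant might look like the sketch below. `rnn_detached_loop` is a hypothetical helper name, and the sketch assumes a single-layer RNN with the default `batch_first=False`. Each call sees a length-1 sequence and a detached hidden state of shape (1, b, h), so the forward values match one full call while gradients stop at every step.

```python
import torch
import torch.nn as nn

def rnn_detached_loop(rnn: nn.RNN, inputs: torch.Tensor, h0: torch.Tensor) -> torch.Tensor:
    """Run a single-layer nn.RNN one time step at a time, detaching the
    hidden state between steps so gradients do not flow recurrently.

    inputs: (l, b, x) because batch_first=False (the default)
    h0:     (1, b, h), i.e. (num_layers, batch, hidden)
    returns (l, b, h)
    """
    hn = h0
    outputs = []
    for t in range(inputs.size(0)):
        # A length-1 slice keeps the (seq_len, batch, features) layout nn.RNN expects.
        out, hn = rnn(inputs[t:t + 1], hn.detach())
        outputs.append(out)
    return torch.cat(outputs, dim=0)

l, b, x, h = 5, 3, 4, 6  # arbitrary example sizes
rnn = nn.RNN(x, h)
inputs = torch.randn(l, b, x)
h0 = torch.zeros(1, b, h)
out = rnn_detached_loop(rnn, inputs, h0)
print(out.shape)  # torch.Size([5, 3, 6])
```

This still pays the Python-loop cost per time step, which is exactly what the question is trying to avoid.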


(Not sure whether this belongs in the autograd category; it could probably fall under NLP as well.)


It feels like if you detach the recurrent connection, it's not really an RNN anymore, right?

I don't think you can modify the fused cells :confused:. The whole point of them is that they are optimized precisely because they do a fixed thing.

It would still be an RNN in the forward pass, but the gradients wouldn’t propagate recurrently.
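To make that concrete, here is a small check with an RNNCell loop (sizes and seed are arbitrary). Detaching the incoming state leaves the forward values unchanged, but a loss on the last time step sends no gradient to earlier inputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
l, b, x, h = 5, 2, 3, 4
cell = nn.RNNCell(x, h)
inputs = torch.randn(l, b, x, requires_grad=True)
h0 = torch.zeros(b, h)

def run(detach: bool) -> torch.Tensor:
    hx, outs = h0, []
    for t in range(l):
        # With detach=True the incoming state is treated as a constant by autograd.
        hx = cell(inputs[t], hx.detach() if detach else hx)
        outs.append(hx)
    return torch.stack(outs)  # (l, b, h)

full, truncated = run(False), run(True)
print(torch.allclose(full, truncated))  # True: identical forward pass

# A loss that depends only on the last time step.
truncated[-1].sum().backward()
print(inputs.grad[:-1].abs().sum().item())     # 0.0: nothing reaches earlier steps
print(inputs.grad[-1].abs().sum().item() > 0)  # True: the current step still gets a gradient
```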

A colleague suggested a clever way to do this (using RNN, not RNNCell):

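  1. Run a normal forward pass with the RNN (created with batch_first=True) and get the sequence of hidden states hidden, i.e. its output (of size b x l x h).
  2. Detach hidden and shift it by one time step (prepend the initial state h0 and drop the last state), so that position t holds the state that enters step t. Then reshape (rather than repeat) the input from b x l x X to bl x 1 x X, and the shifted states to 1 x bl x h, since nn.RNN expects the initial hidden state as num_layers x batch x h even with batch_first=True. This gives a batch of bl sequences of length 1.
  3. Call the same RNN again with these two tensors as input, reshape its output back to b x l x h, and use that as the final output from which the loss and gradients are computed.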
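A minimal sketch of the three steps, assuming a single-layer, unidirectional nn.RNN created with `batch_first=True` (the output of nn.RNN only contains the last layer's states, so this does not carry over directly to stacked RNNs). `rnn_two_pass` is a hypothetical name. The input is reshaped rather than repeated, and the hidden states are shifted by one step so that each length-1 sequence starts from the previous step's state.

```python
import torch
import torch.nn as nn

def rnn_two_pass(rnn: nn.RNN, inputs: torch.Tensor, h0: torch.Tensor) -> torch.Tensor:
    """Assumes a single-layer, unidirectional nn.RNN with batch_first=True.

    inputs: (b, l, x); h0: (1, b, h). Returns outputs of shape (b, l, h).
    """
    b, l, _ = inputs.shape
    # Pass 1: ordinary fused forward pass. Its result is detached anyway,
    # so no autograd graph is built.
    with torch.no_grad():
        hidden, _ = rnn(inputs, h0)  # (b, l, h)
    # The state entering step t is h_{t-1}: prepend h0 and drop the last state.
    prev = torch.cat([h0.detach().transpose(0, 1), hidden[:, :-1]], dim=1)  # (b, l, h)
    # Pass 2: every time step becomes its own length-1 sequence in a batch of b*l.
    step_inputs = inputs.reshape(b * l, 1, -1)  # (b*l, 1, x)
    step_h0 = prev.reshape(1, b * l, -1)        # (num_layers=1, b*l, h)
    out, _ = rnn(step_inputs, step_h0)          # (b*l, 1, h)
    return out.reshape(b, l, -1)

torch.manual_seed(0)
rnn = nn.RNN(input_size=4, hidden_size=6, batch_first=True)
inputs = torch.randn(3, 5, 4, requires_grad=True)  # (b, l, x)
h0 = torch.zeros(1, 3, 6)                          # (num_layers, b, h)
out = rnn_two_pass(rnn, inputs, h0)
print(out.shape)  # torch.Size([3, 5, 6])
out.sum().backward()  # each step's input gets a gradient, but none flows through time
```

The forward values of `out` should match an ordinary `rnn(inputs, h0)` call up to floating-point rounding, and the input gradients should match a step-by-step loop that detaches the hidden state.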

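It does require two forward passes instead of one, but both go through the fused RNN implementation, and the second one handles all bl length-1 sequences in a single call rather than looping over time. The first pass can also run under torch.no_grad(), since its result is detached anyway.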

I guess you will have to try it out to see which one is faster in practice.
But if you need two forward passes to use the RNN, maybe the RNNCell will still be faster.
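One rough way to compare them on CPU is a plain wall-clock harness like the one below. It times forward plus backward for an RNNCell loop with a detached state and for the two-pass nn.RNN trick; the sizes and repeat count are arbitrary assumptions. On a GPU you would also need `torch.cuda.synchronize()` before reading the clock, and `torch.utils.benchmark` is a more careful alternative. Results depend heavily on hardware and sizes, so none are shown here.

```python
import time
import torch
import torch.nn as nn

def time_it(fn, repeats: int = 20) -> float:
    """Average wall-clock seconds per call of fn, after one warm-up call."""
    fn()
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) / repeats

b, l, x, h = 32, 100, 64, 128  # arbitrary example sizes
cell = nn.RNNCell(x, h)
rnn = nn.RNN(x, h, batch_first=True)
inputs = torch.randn(b, l, x)
h0 = torch.zeros(b, h)

def cell_loop():
    # RNNCell stepped in Python, detaching the incoming state at every step.
    hx, outs = h0, []
    for t in range(l):
        hx = cell(inputs[:, t], hx.detach())
        outs.append(hx)
    torch.stack(outs, dim=1).sum().backward()

def two_pass():
    # Pass 1 without a graph, then one batched call over b*l length-1 sequences.
    with torch.no_grad():
        hidden, _ = rnn(inputs, h0.unsqueeze(0))
    prev = torch.cat([h0.unsqueeze(1), hidden[:, :-1]], dim=1)  # states entering each step
    out, _ = rnn(inputs.reshape(b * l, 1, x), prev.reshape(1, b * l, h))
    out.sum().backward()

print(f"RNNCell loop: {time_it(cell_loop) * 1e3:.2f} ms per iteration")
print(f"two-pass RNN: {time_it(two_pass) * 1e3:.2f} ms per iteration")
```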