Is it possible to detach the recurrent connections of a
torch.nn.RNN when calling
forward on it?
I could do this with a
torch.nn.RNNCell, like this (x = input size, h = hidden size, l = sequence length, b = batch size):
rnn = nn.RNNCell(x, h)
input = torch.randn(l, b, x)
hx = torch.randn(b, h)
output = []
for i in range(l):
    hx = rnn(input[i], hx)
    output.append(hx)
    hx = hx.detach()  # cut the gradient through the recurrent connection
Is there a way to do this without having to explicitly iterate over the sequence? I anticipate that doing it with an
RNNCell will make the forward pass much slower.
I could also do it with
torch.nn.RNN, but would still have to iterate over the sequence.
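For reference, that iterative torch.nn.RNN variant would look something like this (a minimal sketch with made-up sizes; each timestep is fed as a length-1 sequence and the hidden state is detached between steps):

```python
import torch
import torch.nn as nn

l, b, x, h = 5, 3, 4, 6  # example sizes: seq len, batch, input size, hidden size
rnn = nn.RNN(x, h)       # default layout: (seq_len, batch, feature)

inp = torch.randn(l, b, x)
hx = torch.zeros(1, b, h)  # (num_layers, batch, hidden_size)

outputs = []
for t in range(l):
    # feed one timestep as a length-1 sequence
    out, hx = rnn(inp[t:t + 1], hx)
    outputs.append(out)
    hx = hx.detach()  # block gradients through the recurrent connection

output = torch.cat(outputs, dim=0)  # (l, b, h)
```

The forward values are identical to a single rnn(inp) call; only the backward graph is chopped at every step.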
(Not sure if the autograd category is the best fit for this; it could probably fall under NLP as well.)
It feels like if you detach the recurrent connection, it’s not really an RNN anymore, right?
I don’t think you can modify the fused cells; the whole point of them is that they are optimized because they do a fixed thing.
It would still be an RNN in the forward pass, but the gradients wouldn’t propagate recurrently.
A colleague suggested a clever way to do this:

- Forward pass with the RNN as usual and get the sequence of hidden states hidden (of size b x l x h).
- Reshape hidden to size bl x h and detach; repeat the input l times to create a new batch of size bl, so your input has size bl x 1 x X and hidden has size bl x 1 x h (basically, sequences of length 1).
- Call the same RNN again with these two tensors as input instead, and take the output of this call as the final output, which will be used to compute the gradients.
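A minimal sketch of that two-pass trick (my own variable names and example sizes; I keep nn.RNN’s default seq-first layout, so the shapes are l x b x h rather than b x l x h):

```python
import torch
import torch.nn as nn

l, b, x, h = 5, 3, 4, 6  # example sizes: seq len, batch, input size, hidden size
rnn = nn.RNN(x, h)       # default layout: (seq_len, batch, feature)

inp = torch.randn(l, b, x)
h0 = torch.zeros(1, b, h)

# Pass 1: ordinary forward pass to collect the hidden state after every step.
# Run under no_grad since these states are about to be detached anyway.
with torch.no_grad():
    states, _ = rnn(inp, h0)  # (l, b, h)

# The "previous" hidden state for step t is h0 for t = 0, states[t-1] otherwise.
prev = torch.cat([h0, states[:-1]], dim=0)  # (l, b, h), no grad history

# Pass 2: fold time into the batch, so every timestep becomes its own
# length-1 sequence whose initial hidden state carries no gradient.
inp2 = inp.reshape(l * b, x).unsqueeze(0)   # (1, l*b, x)
h2 = prev.reshape(l * b, h).unsqueeze(0)    # (1, l*b, h)

out, _ = rnn(inp2, h2)                      # (1, l*b, h)
out = out.squeeze(0).reshape(l, b, h)       # same values as states, gradients stop at each step
```

The second pass reproduces the forward values of the first, but backprop through out only reaches each step’s own input and the RNN weights, never the earlier timesteps.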
It does require two forward passes instead of one, but it still takes advantage of the efficiency of the
RNN forward pass.
I guess you will have to try it out to see which one is faster in practice.
But if you have to do two forward passes to use the RNN, maybe the RNNCell will still be faster.