RNN Open-loop mode OOM

Hi,

So I noticed an OOM issue with RNNs when I run in open-loop mode. For my particular use case, I can't set the inputs to volatile, because I still need to make a backward pass.

This is a sample test case:

import torch
import torch.nn as nn
from torch.autograd import Variable

rnn1 = nn.GRU(63, 512).cuda()
output = nn.Linear(512, 63).cuda()
input_frame = Variable(torch.randn(1, 16, 63)).cuda()  # (seq_len=1, batch=16, features)
hidden = Variable(torch.randn(1, 16, 512)).cuda()      # (num_layers=1, batch=16, hidden)
for i in xrange(2000):
    state, hidden = rnn1(input_frame, hidden)
    input_frame = output(state.squeeze(0)).unsqueeze(0)  # feed the prediction back in

And no, 2000 isn't some arbitrarily large number I picked; it's very much necessary for my particular problem.

One fix I noticed was to use this variant instead:

rnn2 = nn.GRUCell(63, 512).cuda()
output = nn.Linear(512, 63).cuda()
input_frame = Variable(torch.randn(16, 63)).cuda()  # (batch=16, features)
hidden = Variable(torch.randn(16, 512)).cuda()      # (batch=16, hidden)
for i in xrange(2000):
    hidden = rnn2(input_frame, hidden)
    input_frame = output(hidden)  # feed the prediction back in

I did not get OOM errors when using the GRUCell (I'm guessing it doesn't go through the cuDNN RNN kernels?).

So in summary: I use nn.GRU() for training (in teacher-forcing mode) and nn.GRUCell() in open-loop mode (where I'll still need to make a backward pass).

Is there a simple way to tie weights between the GRUCell and GRU layers? Something like:
rnn2.load_state_dict(rnn1.state_dict())

(or is there another simpler fix to this issue?)
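
For reference, a plain rnn2.load_state_dict(rnn1.state_dict()) fails because the two modules name their parameters differently; a quick check (using the rnn1/rnn2 definitions above):

print(rnn1.state_dict().keys())  # ['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0']
print(rnn2.state_dict().keys())  # ['weight_ih', 'weight_hh', 'bias_ih', 'bias_hh']

The underlying tensors have matching shapes, so it is purely a key-naming mismatch.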

Any help is much appreciated. Thanks in advance.

Update:
So a temporary fix I found was this:

gru_state_dict = rnn1.state_dict()
for key in list(gru_state_dict.keys()):  # copy the keys, since we mutate the dict
    # nn.GRU names parameters 'weight_ih_l0' etc., while nn.GRUCell expects
    # 'weight_ih' etc., so strip the layer suffix
    gru_state_dict[key.replace('_l0', '')] = gru_state_dict.pop(key)
rnn2.load_state_dict(gru_state_dict)
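
A slightly tidier equivalent builds the remapped dict in one expression:

rnn2.load_state_dict({k.replace('_l0', ''): v for k, v in rnn1.state_dict().items()})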

Do you have any other, more elegant solutions? I'm guessing the OOM happens specifically because PyTorch by default caches the intermediate states during the forward pass and reuses them for an efficient backward pass? Would it be possible to temporarily switch off this caching, to trade off memory against time?
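
For what it's worth, the tradeoff described above is exactly what gradient checkpointing does: drop intermediate activations during the forward pass and recompute them on the fly during backward. Later PyTorch versions ship it as torch.utils.checkpoint; below is a minimal sketch of the idea in the modern (post-Variable) API, with an arbitrary chunk size of 100 steps and a dummy loss (depending on your version, you may also need to pass use_reentrant=False to checkpoint):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

rnn = nn.GRU(63, 512).cuda()
out = nn.Linear(512, 63).cuda()

def run_chunk(input_frame, hidden):
    # Roll the network forward 100 steps. Under checkpoint(), only the
    # chunk's boundary tensors are stored; the intermediates are
    # recomputed during the backward pass, trading time for memory.
    for _ in range(100):
        state, hidden = rnn(input_frame, hidden)
        input_frame = out(state.squeeze(0)).unsqueeze(0)
    return input_frame, hidden

input_frame = torch.randn(1, 16, 63, device='cuda')
hidden = torch.randn(1, 16, 512, device='cuda', requires_grad=True)
for _ in range(20):  # 20 chunks x 100 steps = 2000 steps
    input_frame, hidden = checkpoint(run_chunk, input_frame, hidden)
input_frame.sum().backward()  # dummy loss, just to exercise the backward pass

Only each chunk's boundary tensors stay resident between chunks, so peak activation memory shrinks by roughly the chunk size, at the cost of one extra forward computation per chunk during backward.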

Best,
Rithesh

This is long gone, I'm sure, but I wanted to say that https://github.com/pytorch/pytorch/pull/1691 will improve memory usage (especially in your case).