Hi,
So I noticed an OOM (out-of-memory) issue with RNNs when I run them in open-loop mode. For my particular use case, I can't set the inputs as volatile, because I will be making a backward pass.
This is a sample test case:
import torch
import torch.nn as nn
from torch.autograd import Variable

rnn1 = nn.GRU(63, 512).cuda()
output = nn.Linear(512, 63).cuda()
input_frame = Variable(torch.randn(16, 1, 63)).cuda()
hidden = Variable(torch.randn(1, 1, 512)).cuda()
for i in xrange(2000):
    # feed the previous prediction back in as the next input (open loop)
    state, hidden = rnn1(input_frame, hidden)
    input_frame = output(state.squeeze()).unsqueeze(1)
And no, 2000 isn't some random large number I'm trying; that sequence length really is necessary for my particular problem.
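In case it helps to reproduce: a quick way to watch the memory climb (this assumes a PyTorch build that exposes torch.cuda.memory_allocated, which older versions may not have) is to print allocator stats inside the same loop:

for i in xrange(2000):
    state, hidden = rnn1(input_frame, hidden)
    input_frame = output(state.squeeze()).unsqueeze(1)
    if i % 100 == 0:
        # allocated memory keeps growing because every iteration extends the autograd graph
        print(i, torch.cuda.memory_allocated() / float(1024 ** 2), 'MB')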
One fix I noticed was to use this variant instead:
rnn2 = nn.GRUCell(63, 512).cuda()
output = nn.Linear(512, 63).cuda()
input_frame = Variable(torch.randn(16, 63)).cuda()
hidden = Variable(torch.randn(16, 512)).cuda()
for i in xrange(2000):
    # same open-loop rollout, one step at a time with GRUCell
    hidden = rnn2(input_frame, hidden)
    input_frame = output(hidden)
I did not get OOM errors when using GRUCell instead (I'm guessing this one doesn't use the cuDNN RNN kernels?).
So, in summary, I use nn.GRU() for training (in teacher-forcing mode) and nn.GRUCell() in open-loop mode (I'll still need to make a backward pass there, though).
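For context, the teacher-forcing pass I mean is just the usual one-shot call over the whole ground-truth sequence (a rough sketch; full_sequence here is a hypothetical (seq_len, batch, 63) tensor standing in for my data):

full_sequence = Variable(torch.randn(200, 16, 63)).cuda()   # hypothetical ground truth
hidden = Variable(torch.zeros(1, 16, 512)).cuda()
states, hidden = rnn1(full_sequence, hidden)                # one cuDNN call over all steps
predictions = output(states.view(-1, 512)).view(200, 16, 63)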
Is there a simple way to tie weights between the GRUCell and GRU layers? Something like:
rnn2.load_state_dict(rnn1.state_dict())
(Or is there another, simpler fix for this issue?)
Any help is much appreciated. Thanks in advance.
Update:
So a temporary fix I found was this:
gru_state_dict = rnn1.state_dict()
for key in list(gru_state_dict.keys()):
    # GRU parameters are named e.g. 'weight_ih_l0'; GRUCell expects 'weight_ih'
    gru_state_dict[key.replace('_l0', '')] = gru_state_dict.pop(key)
rnn2.load_state_dict(gru_state_dict)
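As a sanity check (my own quick test, so treat it as a sketch), one step through both modules should now give essentially the same result, up to floating-point differences between the cuDNN and the cell implementations:

x = Variable(torch.randn(16, 63)).cuda()
h = Variable(torch.randn(16, 512)).cuda()
out_gru, _ = rnn1(x.unsqueeze(0), h.unsqueeze(0))   # GRU expects (seq, batch, features)
out_cell = rnn2(x, h)
print((out_gru.squeeze(0) - out_cell).abs().max())  # should be ~0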
Do you have any more elegant solutions? I'm guessing the OOM errors happen because PyTorch, by default, caches the intermediate states during the forward pass so it can reuse them for an efficient backward pass. Would it be possible to temporarily switch off this caching to trade memory for time?
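For what it's worth, the kind of thing I have in mind is checkpointing each step so activations are recomputed during the backward pass instead of being stored. This is only a sketch, and it assumes a newer PyTorch that ships torch.utils.checkpoint (mine may not); the loss is just a placeholder:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

rnn2 = nn.GRUCell(63, 512).cuda()
output = nn.Linear(512, 63).cuda()

def step(inp, hid):
    # one open-loop step; re-run during backward when checkpointed
    hid = rnn2(inp, hid)
    return output(hid), hid

# requires_grad=True on the initial tensors keeps the checkpointed graph connected
input_frame = torch.randn(16, 63, device='cuda', requires_grad=True)
hidden = torch.randn(16, 512, device='cuda', requires_grad=True)

for i in range(2000):
    input_frame, hidden = checkpoint(step, input_frame, hidden)

loss = input_frame.pow(2).mean()   # placeholder loss, just for illustration
loss.backward()                    # each step's forward is recomputed here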
Best,
Rithesh